mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 15:46:47 +00:00
Merge branch 'master' into pgvectorstore-docs
This commit is contained in:
commit
ac7e73a531
@ -11,6 +11,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import toml
|
import toml
|
||||||
@ -104,7 +105,7 @@ def skip_private_members(app, what, name, obj, skip, options):
|
|||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
project = "🦜🔗 LangChain"
|
project = "🦜🔗 LangChain"
|
||||||
copyright = "2023, LangChain Inc"
|
copyright = f"{datetime.now().year}, LangChain Inc"
|
||||||
author = "LangChain, Inc"
|
author = "LangChain, Inc"
|
||||||
|
|
||||||
html_favicon = "_static/img/brand/favicon.png"
|
html_favicon = "_static/img/brand/favicon.png"
|
||||||
|
@ -36,10 +36,7 @@
|
|||||||
"pip install oracledb"
|
"pip install oracledb"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false
|
||||||
"pycharm": {
|
|
||||||
"is_executing": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -51,10 +48,7 @@
|
|||||||
"from settings import s"
|
"from settings import s"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false
|
||||||
"pycharm": {
|
|
||||||
"is_executing": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -97,16 +91,14 @@
|
|||||||
"doc_2 = doc_loader_2.load()"
|
"doc_2 = doc_loader_2.load()"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false
|
||||||
"pycharm": {
|
|
||||||
"is_executing": true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
"source": [
|
||||||
"With TLS authentication, wallet_location and wallet_password are not required."
|
"With TLS authentication, wallet_location and wallet_password are not required.\n",
|
||||||
|
"Bind variable option is provided by argument \"parameters\"."
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false
|
"collapsed": false
|
||||||
@ -117,6 +109,8 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"SQL_QUERY = \"select channel_id, channel_desc from sh.channels where channel_desc = :1 fetch first 5 rows only\"\n",
|
||||||
|
"\n",
|
||||||
"doc_loader_3 = OracleAutonomousDatabaseLoader(\n",
|
"doc_loader_3 = OracleAutonomousDatabaseLoader(\n",
|
||||||
" query=SQL_QUERY,\n",
|
" query=SQL_QUERY,\n",
|
||||||
" user=s.USERNAME,\n",
|
" user=s.USERNAME,\n",
|
||||||
@ -124,6 +118,7 @@
|
|||||||
" schema=s.SCHEMA,\n",
|
" schema=s.SCHEMA,\n",
|
||||||
" config_dir=s.CONFIG_DIR,\n",
|
" config_dir=s.CONFIG_DIR,\n",
|
||||||
" tns_name=s.TNS_NAME,\n",
|
" tns_name=s.TNS_NAME,\n",
|
||||||
|
" parameters=[\"Direct Sales\"],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"doc_3 = doc_loader_3.load()\n",
|
"doc_3 = doc_loader_3.load()\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -133,6 +128,7 @@
|
|||||||
" password=s.PASSWORD,\n",
|
" password=s.PASSWORD,\n",
|
||||||
" schema=s.SCHEMA,\n",
|
" schema=s.SCHEMA,\n",
|
||||||
" connection_string=s.CONNECTION_STRING,\n",
|
" connection_string=s.CONNECTION_STRING,\n",
|
||||||
|
" parameters=[\"Direct Sales\"],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"doc_4 = doc_loader_4.load()"
|
"doc_4 = doc_loader_4.load()"
|
||||||
],
|
],
|
||||||
|
@ -183,7 +183,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Initalize simple_datasource_qa for querying Tableau Datasources through VDS\n",
|
"# Initialize simple_datasource_qa for querying Tableau Datasources through VDS\n",
|
||||||
"analyze_datasource = initialize_simple_datasource_qa(\n",
|
"analyze_datasource = initialize_simple_datasource_qa(\n",
|
||||||
" domain=tableau_server,\n",
|
" domain=tableau_server,\n",
|
||||||
" site=tableau_site,\n",
|
" site=tableau_site,\n",
|
||||||
|
@ -10,6 +10,38 @@ from langchain_core.messages import AIMessage
|
|||||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||||
|
|
||||||
MODEL_COST_PER_1K_TOKENS = {
|
MODEL_COST_PER_1K_TOKENS = {
|
||||||
|
# GPT-4.1 input
|
||||||
|
"gpt-4.1": 0.002,
|
||||||
|
"gpt-4.1-2025-04-14": 0.002,
|
||||||
|
"gpt-4.1-cached": 0.0005,
|
||||||
|
"gpt-4.1-2025-04-14-cached": 0.0005,
|
||||||
|
# GPT-4.1 output
|
||||||
|
"gpt-4.1-completion": 0.008,
|
||||||
|
"gpt-4.1-2025-04-14-completion": 0.008,
|
||||||
|
# GPT-4.1-mini input
|
||||||
|
"gpt-4.1-mini": 0.0004,
|
||||||
|
"gpt-4.1-mini-2025-04-14": 0.0004,
|
||||||
|
"gpt-4.1-mini-cached": 0.0001,
|
||||||
|
"gpt-4.1-mini-2025-04-14-cached": 0.0001,
|
||||||
|
# GPT-4.1-mini output
|
||||||
|
"gpt-4.1-mini-completion": 0.0016,
|
||||||
|
"gpt-4.1-mini-2025-04-14-completion": 0.0016,
|
||||||
|
# GPT-4.1-nano input
|
||||||
|
"gpt-4.1-nano": 0.0001,
|
||||||
|
"gpt-4.1-nano-2025-04-14": 0.0001,
|
||||||
|
"gpt-4.1-nano-cached": 0.000025,
|
||||||
|
"gpt-4.1-nano-2025-04-14-cached": 0.000025,
|
||||||
|
# GPT-4.1-nano output
|
||||||
|
"gpt-4.1-nano-completion": 0.0004,
|
||||||
|
"gpt-4.1-nano-2025-04-14-completion": 0.0004,
|
||||||
|
# GPT-4.5-preview input
|
||||||
|
"gpt-4.5-preview": 0.075,
|
||||||
|
"gpt-4.5-preview-2025-02-27": 0.075,
|
||||||
|
"gpt-4.5-preview-cached": 0.0375,
|
||||||
|
"gpt-4.5-preview-2025-02-27-cached": 0.0375,
|
||||||
|
# GPT-4.5-preview output
|
||||||
|
"gpt-4.5-preview-completion": 0.15,
|
||||||
|
"gpt-4.5-preview-2025-02-27-completion": 0.15,
|
||||||
# OpenAI o1 input
|
# OpenAI o1 input
|
||||||
"o1": 0.015,
|
"o1": 0.015,
|
||||||
"o1-2024-12-17": 0.015,
|
"o1-2024-12-17": 0.015,
|
||||||
@ -18,6 +50,28 @@ MODEL_COST_PER_1K_TOKENS = {
|
|||||||
# OpenAI o1 output
|
# OpenAI o1 output
|
||||||
"o1-completion": 0.06,
|
"o1-completion": 0.06,
|
||||||
"o1-2024-12-17-completion": 0.06,
|
"o1-2024-12-17-completion": 0.06,
|
||||||
|
# OpenAI o1-pro input
|
||||||
|
"o1-pro": 0.15,
|
||||||
|
"o1-pro-2025-03-19": 0.15,
|
||||||
|
# OpenAI o1-pro output
|
||||||
|
"o1-pro-completion": 0.6,
|
||||||
|
"o1-pro-2025-03-19-completion": 0.6,
|
||||||
|
# OpenAI o3 input
|
||||||
|
"o3": 0.01,
|
||||||
|
"o3-2025-04-16": 0.01,
|
||||||
|
"o3-cached": 0.0025,
|
||||||
|
"o3-2025-04-16-cached": 0.0025,
|
||||||
|
# OpenAI o3 output
|
||||||
|
"o3-completion": 0.04,
|
||||||
|
"o3-2025-04-16-completion": 0.04,
|
||||||
|
# OpenAI o4-mini input
|
||||||
|
"o4-mini": 0.0011,
|
||||||
|
"o4-mini-2025-04-16": 0.0011,
|
||||||
|
"o4-mini-cached": 0.000275,
|
||||||
|
"o4-mini-2025-04-16-cached": 0.000275,
|
||||||
|
# OpenAI o4-mini output
|
||||||
|
"o4-mini-completion": 0.0044,
|
||||||
|
"o4-mini-2025-04-16-completion": 0.0044,
|
||||||
# OpenAI o3-mini input
|
# OpenAI o3-mini input
|
||||||
"o3-mini": 0.0011,
|
"o3-mini": 0.0011,
|
||||||
"o3-mini-2025-01-31": 0.0011,
|
"o3-mini-2025-01-31": 0.0011,
|
||||||
@ -26,6 +80,14 @@ MODEL_COST_PER_1K_TOKENS = {
|
|||||||
# OpenAI o3-mini output
|
# OpenAI o3-mini output
|
||||||
"o3-mini-completion": 0.0044,
|
"o3-mini-completion": 0.0044,
|
||||||
"o3-mini-2025-01-31-completion": 0.0044,
|
"o3-mini-2025-01-31-completion": 0.0044,
|
||||||
|
# OpenAI o1-mini input (updated pricing)
|
||||||
|
"o1-mini": 0.0011,
|
||||||
|
"o1-mini-cached": 0.00055,
|
||||||
|
"o1-mini-2024-09-12": 0.0011,
|
||||||
|
"o1-mini-2024-09-12-cached": 0.00055,
|
||||||
|
# OpenAI o1-mini output (updated pricing)
|
||||||
|
"o1-mini-completion": 0.0044,
|
||||||
|
"o1-mini-2024-09-12-completion": 0.0044,
|
||||||
# OpenAI o1-preview input
|
# OpenAI o1-preview input
|
||||||
"o1-preview": 0.015,
|
"o1-preview": 0.015,
|
||||||
"o1-preview-cached": 0.0075,
|
"o1-preview-cached": 0.0075,
|
||||||
@ -34,22 +96,6 @@ MODEL_COST_PER_1K_TOKENS = {
|
|||||||
# OpenAI o1-preview output
|
# OpenAI o1-preview output
|
||||||
"o1-preview-completion": 0.06,
|
"o1-preview-completion": 0.06,
|
||||||
"o1-preview-2024-09-12-completion": 0.06,
|
"o1-preview-2024-09-12-completion": 0.06,
|
||||||
# OpenAI o1-mini input
|
|
||||||
"o1-mini": 0.003,
|
|
||||||
"o1-mini-cached": 0.0015,
|
|
||||||
"o1-mini-2024-09-12": 0.003,
|
|
||||||
"o1-mini-2024-09-12-cached": 0.0015,
|
|
||||||
# OpenAI o1-mini output
|
|
||||||
"o1-mini-completion": 0.012,
|
|
||||||
"o1-mini-2024-09-12-completion": 0.012,
|
|
||||||
# GPT-4o-mini input
|
|
||||||
"gpt-4o-mini": 0.00015,
|
|
||||||
"gpt-4o-mini-cached": 0.000075,
|
|
||||||
"gpt-4o-mini-2024-07-18": 0.00015,
|
|
||||||
"gpt-4o-mini-2024-07-18-cached": 0.000075,
|
|
||||||
# GPT-4o-mini output
|
|
||||||
"gpt-4o-mini-completion": 0.0006,
|
|
||||||
"gpt-4o-mini-2024-07-18-completion": 0.0006,
|
|
||||||
# GPT-4o input
|
# GPT-4o input
|
||||||
"gpt-4o": 0.0025,
|
"gpt-4o": 0.0025,
|
||||||
"gpt-4o-cached": 0.00125,
|
"gpt-4o-cached": 0.00125,
|
||||||
@ -63,6 +109,65 @@ MODEL_COST_PER_1K_TOKENS = {
|
|||||||
"gpt-4o-2024-05-13-completion": 0.015,
|
"gpt-4o-2024-05-13-completion": 0.015,
|
||||||
"gpt-4o-2024-08-06-completion": 0.01,
|
"gpt-4o-2024-08-06-completion": 0.01,
|
||||||
"gpt-4o-2024-11-20-completion": 0.01,
|
"gpt-4o-2024-11-20-completion": 0.01,
|
||||||
|
# GPT-4o-audio-preview input
|
||||||
|
"gpt-4o-audio-preview": 0.0025,
|
||||||
|
"gpt-4o-audio-preview-2024-12-17": 0.0025,
|
||||||
|
"gpt-4o-audio-preview-2024-10-01": 0.0025,
|
||||||
|
# GPT-4o-audio-preview output
|
||||||
|
"gpt-4o-audio-preview-completion": 0.01,
|
||||||
|
"gpt-4o-audio-preview-2024-12-17-completion": 0.01,
|
||||||
|
"gpt-4o-audio-preview-2024-10-01-completion": 0.01,
|
||||||
|
# GPT-4o-realtime-preview input
|
||||||
|
"gpt-4o-realtime-preview": 0.005,
|
||||||
|
"gpt-4o-realtime-preview-2024-12-17": 0.005,
|
||||||
|
"gpt-4o-realtime-preview-2024-10-01": 0.005,
|
||||||
|
"gpt-4o-realtime-preview-cached": 0.0025,
|
||||||
|
"gpt-4o-realtime-preview-2024-12-17-cached": 0.0025,
|
||||||
|
"gpt-4o-realtime-preview-2024-10-01-cached": 0.0025,
|
||||||
|
# GPT-4o-realtime-preview output
|
||||||
|
"gpt-4o-realtime-preview-completion": 0.02,
|
||||||
|
"gpt-4o-realtime-preview-2024-12-17-completion": 0.02,
|
||||||
|
"gpt-4o-realtime-preview-2024-10-01-completion": 0.02,
|
||||||
|
# GPT-4o-mini input
|
||||||
|
"gpt-4o-mini": 0.00015,
|
||||||
|
"gpt-4o-mini-cached": 0.000075,
|
||||||
|
"gpt-4o-mini-2024-07-18": 0.00015,
|
||||||
|
"gpt-4o-mini-2024-07-18-cached": 0.000075,
|
||||||
|
# GPT-4o-mini output
|
||||||
|
"gpt-4o-mini-completion": 0.0006,
|
||||||
|
"gpt-4o-mini-2024-07-18-completion": 0.0006,
|
||||||
|
# GPT-4o-mini-audio-preview input
|
||||||
|
"gpt-4o-mini-audio-preview": 0.00015,
|
||||||
|
"gpt-4o-mini-audio-preview-2024-12-17": 0.00015,
|
||||||
|
# GPT-4o-mini-audio-preview output
|
||||||
|
"gpt-4o-mini-audio-preview-completion": 0.0006,
|
||||||
|
"gpt-4o-mini-audio-preview-2024-12-17-completion": 0.0006,
|
||||||
|
# GPT-4o-mini-realtime-preview input
|
||||||
|
"gpt-4o-mini-realtime-preview": 0.0006,
|
||||||
|
"gpt-4o-mini-realtime-preview-2024-12-17": 0.0006,
|
||||||
|
"gpt-4o-mini-realtime-preview-cached": 0.0003,
|
||||||
|
"gpt-4o-mini-realtime-preview-2024-12-17-cached": 0.0003,
|
||||||
|
# GPT-4o-mini-realtime-preview output
|
||||||
|
"gpt-4o-mini-realtime-preview-completion": 0.0024,
|
||||||
|
"gpt-4o-mini-realtime-preview-2024-12-17-completion": 0.0024,
|
||||||
|
# GPT-4o-mini-search-preview input
|
||||||
|
"gpt-4o-mini-search-preview": 0.00015,
|
||||||
|
"gpt-4o-mini-search-preview-2025-03-11": 0.00015,
|
||||||
|
# GPT-4o-mini-search-preview output
|
||||||
|
"gpt-4o-mini-search-preview-completion": 0.0006,
|
||||||
|
"gpt-4o-mini-search-preview-2025-03-11-completion": 0.0006,
|
||||||
|
# GPT-4o-search-preview input
|
||||||
|
"gpt-4o-search-preview": 0.0025,
|
||||||
|
"gpt-4o-search-preview-2025-03-11": 0.0025,
|
||||||
|
# GPT-4o-search-preview output
|
||||||
|
"gpt-4o-search-preview-completion": 0.01,
|
||||||
|
"gpt-4o-search-preview-2025-03-11-completion": 0.01,
|
||||||
|
# Computer-use-preview input
|
||||||
|
"computer-use-preview": 0.003,
|
||||||
|
"computer-use-preview-2025-03-11": 0.003,
|
||||||
|
# Computer-use-preview output
|
||||||
|
"computer-use-preview-completion": 0.012,
|
||||||
|
"computer-use-preview-2025-03-11-completion": 0.012,
|
||||||
# GPT-4 input
|
# GPT-4 input
|
||||||
"gpt-4": 0.03,
|
"gpt-4": 0.03,
|
||||||
"gpt-4-0314": 0.03,
|
"gpt-4-0314": 0.03,
|
||||||
@ -219,6 +324,7 @@ def standardize_model_name(
|
|||||||
or model_name.startswith("gpt-35")
|
or model_name.startswith("gpt-35")
|
||||||
or model_name.startswith("o1-")
|
or model_name.startswith("o1-")
|
||||||
or model_name.startswith("o3-")
|
or model_name.startswith("o3-")
|
||||||
|
or model_name.startswith("o4-")
|
||||||
or ("finetuned" in model_name and "legacy" not in model_name)
|
or ("finetuned" in model_name and "legacy" not in model_name)
|
||||||
):
|
):
|
||||||
return model_name + "-completion"
|
return model_name + "-completion"
|
||||||
@ -226,8 +332,10 @@ def standardize_model_name(
|
|||||||
token_type == TokenType.PROMPT_CACHED
|
token_type == TokenType.PROMPT_CACHED
|
||||||
and (
|
and (
|
||||||
model_name.startswith("gpt-4o")
|
model_name.startswith("gpt-4o")
|
||||||
|
or model_name.startswith("gpt-4.1")
|
||||||
or model_name.startswith("o1")
|
or model_name.startswith("o1")
|
||||||
or model_name.startswith("o3")
|
or model_name.startswith("o3")
|
||||||
|
or model_name.startswith("o4")
|
||||||
)
|
)
|
||||||
and not (model_name.startswith("gpt-4o-2024-05-13"))
|
and not (model_name.startswith("gpt-4o-2024-05-13"))
|
||||||
):
|
):
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@ -31,6 +31,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
|||||||
wallet_password: Optional[str] = None,
|
wallet_password: Optional[str] = None,
|
||||||
connection_string: Optional[str] = None,
|
connection_string: Optional[str] = None,
|
||||||
metadata: Optional[List[str]] = None,
|
metadata: Optional[List[str]] = None,
|
||||||
|
parameters: Optional[Union[list, tuple, dict]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
init method
|
init method
|
||||||
@ -44,6 +45,7 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
|||||||
:param wallet_password: password of wallet
|
:param wallet_password: password of wallet
|
||||||
:param connection_string: connection string to connect to adb instance
|
:param connection_string: connection string to connect to adb instance
|
||||||
:param metadata: metadata used in document
|
:param metadata: metadata used in document
|
||||||
|
:param parameters: bind variable to use in query
|
||||||
"""
|
"""
|
||||||
# Mandatory required arguments.
|
# Mandatory required arguments.
|
||||||
self.query = query
|
self.query = query
|
||||||
@ -67,6 +69,9 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
|||||||
# metadata column
|
# metadata column
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
|
# parameters, e.g bind variable
|
||||||
|
self.parameters = parameters
|
||||||
|
|
||||||
# dsn
|
# dsn
|
||||||
self.dsn: Optional[str]
|
self.dsn: Optional[str]
|
||||||
self._set_dsn()
|
self._set_dsn()
|
||||||
@ -96,7 +101,10 @@ class OracleAutonomousDatabaseLoader(BaseLoader):
|
|||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
if self.schema:
|
if self.schema:
|
||||||
cursor.execute(f"alter session set current_schema={self.schema}")
|
cursor.execute(f"alter session set current_schema={self.schema}")
|
||||||
cursor.execute(self.query)
|
if self.parameters:
|
||||||
|
cursor.execute(self.query, self.parameters)
|
||||||
|
else:
|
||||||
|
cursor.execute(self.query)
|
||||||
columns = [col[0] for col in cursor.description]
|
columns = [col[0] for col in cursor.description]
|
||||||
data = cursor.fetchall()
|
data = cursor.fetchall()
|
||||||
data = [
|
data = [
|
||||||
|
@ -668,7 +668,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||||
# than the maximum context and use length-safe embedding function.
|
# than the maximum context and use length-safe embedding function.
|
||||||
engine = cast(str, self.deployment)
|
engine = cast(str, self.deployment)
|
||||||
return self._get_len_safe_embeddings(texts, engine=engine)
|
return self._get_len_safe_embeddings(
|
||||||
|
texts, engine=engine, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
async def aembed_documents(
|
async def aembed_documents(
|
||||||
self, texts: List[str], chunk_size: Optional[int] = 0
|
self, texts: List[str], chunk_size: Optional[int] = 0
|
||||||
@ -686,7 +688,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
# NOTE: to keep things simple, we assume the list may contain texts longer
|
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||||
# than the maximum context and use length-safe embedding function.
|
# than the maximum context and use length-safe embedding function.
|
||||||
engine = cast(str, self.deployment)
|
engine = cast(str, self.deployment)
|
||||||
return await self._aget_len_safe_embeddings(texts, engine=engine)
|
return self._get_len_safe_embeddings(
|
||||||
|
texts, engine=engine, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
from langchain_community.embeddings.openai import OpenAIEmbeddings
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_openai_invalid_model_kwargs() -> None:
|
def test_openai_invalid_model_kwargs() -> None:
|
||||||
@ -14,3 +19,20 @@ def test_openai_incorrect_field() -> None:
|
|||||||
with pytest.warns(match="not default parameter"):
|
with pytest.warns(match="not default parameter"):
|
||||||
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
|
llm = OpenAIEmbeddings(foo="bar", openai_api_key="foo") # type: ignore[call-arg]
|
||||||
assert llm.model_kwargs == {"foo": "bar"}
|
assert llm.model_kwargs == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_embed_documents_with_custom_chunk_size() -> None:
|
||||||
|
embeddings = OpenAIEmbeddings(chunk_size=2)
|
||||||
|
texts = ["text1", "text2", "text3", "text4"]
|
||||||
|
custom_chunk_size = 3
|
||||||
|
|
||||||
|
with patch.object(embeddings.client, "create") as mock_create:
|
||||||
|
mock_create.side_effect = [
|
||||||
|
{"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]},
|
||||||
|
{"data": [{"embedding": [0.5, 0.6]}, {"embedding": [0.7, 0.8]}]},
|
||||||
|
]
|
||||||
|
|
||||||
|
embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
|
||||||
|
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||||
|
mock_create.assert_any_call(input=[[1342, 19]], **embeddings._invocation_params)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Annotated, Any, Optional
|
from typing import Annotated, Any, Optional
|
||||||
|
|
||||||
@ -16,6 +17,8 @@ from langchain_core.outputs import ChatGeneration, Generation
|
|||||||
from langchain_core.utils.json import parse_partial_json
|
from langchain_core.utils.json import parse_partial_json
|
||||||
from langchain_core.utils.pydantic import TypeBaseModel
|
from langchain_core.utils.pydantic import TypeBaseModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_call(
|
def parse_tool_call(
|
||||||
raw_tool_call: dict[str, Any],
|
raw_tool_call: dict[str, Any],
|
||||||
@ -250,6 +253,14 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
return parsed_result
|
return parsed_result
|
||||||
|
|
||||||
|
|
||||||
|
# Common cause of ValidationError is truncated output due to max_tokens.
|
||||||
|
_MAX_TOKENS_ERROR = (
|
||||||
|
"Output parser received a `max_tokens` stop reason. "
|
||||||
|
"The output is likely incomplete—please increase `max_tokens` "
|
||||||
|
"or shorten your prompt."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PydanticToolsParser(JsonOutputToolsParser):
|
class PydanticToolsParser(JsonOutputToolsParser):
|
||||||
"""Parse tools from OpenAI response."""
|
"""Parse tools from OpenAI response."""
|
||||||
|
|
||||||
@ -296,6 +307,14 @@ class PydanticToolsParser(JsonOutputToolsParser):
|
|||||||
except (ValidationError, ValueError):
|
except (ValidationError, ValueError):
|
||||||
if partial:
|
if partial:
|
||||||
continue
|
continue
|
||||||
|
has_max_tokens_stop_reason = any(
|
||||||
|
generation.message.response_metadata.get("stop_reason")
|
||||||
|
== "max_tokens"
|
||||||
|
for generation in result
|
||||||
|
if isinstance(generation, ChatGeneration)
|
||||||
|
)
|
||||||
|
if has_max_tokens_stop_reason:
|
||||||
|
logger.exception(_MAX_TOKENS_ERROR)
|
||||||
raise
|
raise
|
||||||
if self.first_tool_only:
|
if self.first_tool_only:
|
||||||
return pydantic_objects[0] if pydantic_objects else None
|
return pydantic_objects[0] if pydantic_objects else None
|
||||||
|
@ -108,7 +108,7 @@ class Node(NamedTuple):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
data: Union[type[BaseModel], RunnableType]
|
data: Union[type[BaseModel], RunnableType, None]
|
||||||
metadata: Optional[dict[str, Any]]
|
metadata: Optional[dict[str, Any]]
|
||||||
|
|
||||||
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
def copy(self, *, id: Optional[str] = None, name: Optional[str] = None) -> Node:
|
||||||
@ -181,7 +181,7 @@ class MermaidDrawMethod(Enum):
|
|||||||
API = "api" # Uses Mermaid.INK API to render the graph
|
API = "api" # Uses Mermaid.INK API to render the graph
|
||||||
|
|
||||||
|
|
||||||
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
|
def node_data_str(id: str, data: Union[type[BaseModel], RunnableType, None]) -> str:
|
||||||
"""Convert the data of a node to a string.
|
"""Convert the data of a node to a string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -193,7 +193,7 @@ def node_data_str(id: str, data: Union[type[BaseModel], RunnableType]) -> str:
|
|||||||
"""
|
"""
|
||||||
from langchain_core.runnables.base import Runnable
|
from langchain_core.runnables.base import Runnable
|
||||||
|
|
||||||
if not is_uuid(id):
|
if not is_uuid(id) or data is None:
|
||||||
return id
|
return id
|
||||||
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
|
data_str = data.get_name() if isinstance(data, Runnable) else data.__name__
|
||||||
return data_str if not data_str.startswith("Runnable") else data_str[8:]
|
return data_str if not data_str.startswith("Runnable") else data_str[8:]
|
||||||
@ -215,8 +215,10 @@ def node_data_json(
|
|||||||
from langchain_core.load.serializable import to_json_not_implemented
|
from langchain_core.load.serializable import to_json_not_implemented
|
||||||
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
from langchain_core.runnables.base import Runnable, RunnableSerializable
|
||||||
|
|
||||||
if isinstance(node.data, RunnableSerializable):
|
if node.data is None:
|
||||||
json: dict[str, Any] = {
|
json: dict[str, Any] = {}
|
||||||
|
elif isinstance(node.data, RunnableSerializable):
|
||||||
|
json = {
|
||||||
"type": "runnable",
|
"type": "runnable",
|
||||||
"data": {
|
"data": {
|
||||||
"id": node.data.lc_id(),
|
"id": node.data.lc_id(),
|
||||||
@ -317,7 +319,7 @@ class Graph:
|
|||||||
|
|
||||||
def add_node(
|
def add_node(
|
||||||
self,
|
self,
|
||||||
data: Union[type[BaseModel], RunnableType],
|
data: Union[type[BaseModel], RunnableType, None],
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None,
|
||||||
*,
|
*,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
|
@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Iterator
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -635,3 +635,24 @@ def test_parse_with_different_pydantic_1_proper() -> None:
|
|||||||
forecast="Sunny",
|
forecast="Sunny",
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_tokens_error(caplog: Any) -> None:
|
||||||
|
parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
|
||||||
|
input = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call_OwL7f5PE",
|
||||||
|
"name": "NameCollector",
|
||||||
|
"args": {"names": ["suz", "jerm"]},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
response_metadata={"stop_reason": "max_tokens"},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
_ = parser.invoke(input)
|
||||||
|
assert any(
|
||||||
|
"`max_tokens` stop reason" in msg and record.levelname == "ERROR"
|
||||||
|
for record, msg in zip(caplog.records, caplog.messages)
|
||||||
|
)
|
||||||
|
@ -82,8 +82,8 @@ class AnthropicTool(TypedDict):
|
|||||||
"""Anthropic tool definition."""
|
"""Anthropic tool definition."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
input_schema: dict[str, Any]
|
||||||
|
description: NotRequired[str]
|
||||||
cache_control: NotRequired[dict[str, str]]
|
cache_control: NotRequired[dict[str, str]]
|
||||||
|
|
||||||
|
|
||||||
@ -1675,9 +1675,10 @@ def convert_to_anthropic_tool(
|
|||||||
oai_formatted = convert_to_openai_tool(tool)["function"]
|
oai_formatted = convert_to_openai_tool(tool)["function"]
|
||||||
anthropic_formatted = AnthropicTool(
|
anthropic_formatted = AnthropicTool(
|
||||||
name=oai_formatted["name"],
|
name=oai_formatted["name"],
|
||||||
description=oai_formatted["description"],
|
|
||||||
input_schema=oai_formatted["parameters"],
|
input_schema=oai_formatted["parameters"],
|
||||||
)
|
)
|
||||||
|
if "description" in oai_formatted:
|
||||||
|
anthropic_formatted["description"] = oai_formatted["description"]
|
||||||
return anthropic_formatted
|
return anthropic_formatted
|
||||||
|
|
||||||
|
|
||||||
|
@ -931,3 +931,12 @@ def test_anthropic_bind_tools_tool_choice() -> None:
|
|||||||
assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == {
|
assert cast(RunnableBinding, chat_model_with_tools).kwargs["tool_choice"] == {
|
||||||
"type": "any"
|
"type": "any"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_optional_description() -> None:
|
||||||
|
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
|
||||||
|
|
||||||
|
class SampleModel(BaseModel):
|
||||||
|
sample_field: str
|
||||||
|
|
||||||
|
_ = llm.with_structured_output(SampleModel.model_json_schema())
|
||||||
|
Loading…
Reference in New Issue
Block a user