From 0da5078cad3a689e51cd451c7cac63ac35b5accd Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Sun, 14 Jul 2024 18:11:01 -0700
Subject: [PATCH] langchain[minor]: Generic configurable model (#23419)

alternative to [23244](https://github.com/langchain-ai/langchain/pull/23244).

allows you to use chat model declarative methods

![Screenshot 2024-06-25 at 1 07 10 PM](https://github.com/langchain-ai/langchain/assets/22008038/910d1694-9b7b-46bc-bc2e-3792df9321d6)
---
 .../how_to/chat_models_universal_init.ipynb   | 243 ++++++-
 libs/core/langchain_core/runnables/base.py    |   2 +-
 libs/langchain/langchain/chat_models/base.py  | 608 +++++++++++++++++-
 libs/langchain/poetry.lock                    |  87 ++-
 libs/langchain/pyproject.toml                 |   4 +
 .../integration_tests/chat_models/__init__.py |   0
 .../chat_models/test_base.py                  |  59 ++
 .../tests/unit_tests/chat_models/test_base.py | 153 ++++-
 .../tests/unit_tests/test_dependencies.py     |   1 +
 9 files changed, 1067 insertions(+), 90 deletions(-)
 create mode 100644 libs/langchain/tests/integration_tests/chat_models/__init__.py
 create mode 100644 libs/langchain/tests/integration_tests/chat_models/test_base.py

diff --git a/docs/docs/how_to/chat_models_universal_init.ipynb b/docs/docs/how_to/chat_models_universal_init.ipynb
index c77083cdfb1..7c304b498a4 100644
--- a/docs/docs/how_to/chat_models_universal_init.ipynb
+++ b/docs/docs/how_to/chat_models_universal_init.ipynb
@@ -25,7 +25,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install -qU langchain langchain-openai langchain-anthropic langchain-google-vertexai"
+    "%pip install -qU langchain>=0.2.7 langchain-openai langchain-anthropic langchain-google-vertexai"
    ]
   },
   {
@@ -76,32 +76,6 @@
     "print(\"Gemini 1.5: \" + gemini_15.invoke(\"what's your name\").content + \"\\n\")"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "fff9a4c8-b6ee-4a1a-8d3d-0ecaa312d4ed",
-   "metadata": {},
-   "source": [
-    "## Simple config example"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "75c25d39-bf47-4b51-a6c6-64d9c572bfd6",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "user_config = {\n",
-    "    \"model\": \"...user-specified...\",\n",
-    "    \"model_provider\": \"...user-specified...\",\n",
-    "    \"temperature\": 0,\n",
-    "    \"max_tokens\": 1000,\n",
-    "}\n",
-    "\n",
-    "llm = init_chat_model(**user_config)\n",
-    "llm.invoke(\"what's your name\")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "f811f219-5e78-4b62-b495-915d52a22532",
@@ -125,12 +99,215 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "da07b5c0-d2e6-42e4-bfcd-2efcfaae6221",
+   "cell_type": "markdown",
+   "id": "476a44db-c50d-4846-951d-0f1c9ba8bbaa",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "## Creating a configurable model\n",
+    "\n",
+    "You can also create a runtime-configurable model by specifying `configurable_fields`. If you don't specify a `model` value, then \"model\" and \"model_provider\" will be configurable by default."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "6c037f27-12d7-4e83-811e-4245c0e3ba58",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! 
How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_d576307f90', 'finish_reason': 'stop', 'logprobs': None}, id='run-5428ab5c-b5c0-46de-9946-5d4ca40dbdc8-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "configurable_model = init_chat_model(temperature=0)\n", + "\n", + "configurable_model.invoke(\n", + " \"what's your name\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "321e3036-abd2-4e1f-bcc6-606efd036954", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"My name is Claude. It's nice to meet you!\", response_metadata={'id': 'msg_012XvotUJ3kGLXJUWKBVxJUi', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-1ad1eefe-f1c6-4244-8bc6-90e2cb7ee554-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "configurable_model.invoke(\n", + " \"what's your name\", config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7f3b3d4a-4066-45e4-8297-ea81ac8e70b7", + "metadata": {}, + "source": [ + "### Configurable model with default values\n", + "\n", + "We can create a configurable model with default model values, specify which parameters are configurable, and add prefixes to configurable params:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "814a2289-d0db-401e-b555-d5116112b413", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"I'm an AI language model created by OpenAI, and I don't have a personal name. You can call me Assistant or any other name you prefer! How can I assist you today?\", response_metadata={'token_usage': {'completion_tokens': 37, 'prompt_tokens': 11, 'total_tokens': 48}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_ce0793330f', 'finish_reason': 'stop', 'logprobs': None}, id='run-3923e328-7715-4cd6-b215-98e4b6bf7c9d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 37, 'total_tokens': 48})" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "first_llm = init_chat_model(\n", + " model=\"gpt-4o\",\n", + " temperature=0,\n", + " configurable_fields=(\"model\", \"model_provider\", \"temperature\", \"max_tokens\"),\n", + " config_prefix=\"first\", # useful when you have a chain with multiple models\n", + ")\n", + "\n", + "first_llm.invoke(\"what's your name\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "6c8755ba-c001-4f5a-a497-be3f1db83244", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"My name is Claude. 
It's nice to meet you!\", response_metadata={'id': 'msg_01RyYR64DoMPNCfHeNnroMXm', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 11, 'output_tokens': 15}}, id='run-22446159-3723-43e6-88df-b84797e7751d-0', usage_metadata={'input_tokens': 11, 'output_tokens': 15, 'total_tokens': 26})"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "first_llm.invoke(\n",
+    "    \"what's your name\",\n",
+    "    config={\n",
+    "        \"configurable\": {\n",
+    "            \"first_model\": \"claude-3-5-sonnet-20240620\",\n",
+    "            \"first_temperature\": 0.5,\n",
+    "            \"first_max_tokens\": 100,\n",
+    "        }\n",
+    "    },\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0072b1a3-7e44-4b4e-8b07-efe1ba91a689",
+   "metadata": {},
+   "source": [
+    "### Using a configurable model declaratively\n",
+    "\n",
+    "We can call declarative operations like `bind_tools`, `with_structured_output`, `with_config`, etc. on a configurable model and chain a configurable model in the same way that we would a regularly instantiated chat model object."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "067dabee-1050-4110-ae24-c48eba01e13b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'name': 'GetPopulation',\n",
+       "  'args': {'location': 'Los Angeles, CA'},\n",
+       "  'id': 'call_sYT3PFMufHGWJD32Hi2CTNUP'},\n",
+       " {'name': 'GetPopulation',\n",
+       "  'args': {'location': 'New York, NY'},\n",
+       "  'id': 'call_j1qjhxRnD3ffQmRyqjlI1Lnk'}]"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+    "\n",
+    "\n",
+    "class GetWeather(BaseModel):\n",
+    "    \"\"\"Get the current weather in a given location\"\"\"\n",
+    "\n",
+    "    location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
+    "\n",
+    "\n",
+    "class GetPopulation(BaseModel):\n",
+    "    \"\"\"Get the current population in a given location\"\"\"\n",
+    "\n",
+    "    location: str = Field(..., description=\"The city and state, e.g. 
San Francisco, CA\")\n", + "\n", + "\n", + "llm = init_chat_model(temperature=0)\n", + "llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])\n", + "\n", + "llm_with_tools.invoke(\n", + " \"what's bigger in 2024 LA or NYC\", config={\"configurable\": {\"model\": \"gpt-4o\"}}\n", + ").tool_calls" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e57dfe9f-cd24-4e37-9ce9-ccf8daf78f89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetPopulation',\n", + " 'args': {'location': 'Los Angeles, CA'},\n", + " 'id': 'toolu_01CxEHxKtVbLBrvzFS7GQ5xR'},\n", + " {'name': 'GetPopulation',\n", + " 'args': {'location': 'New York City, NY'},\n", + " 'id': 'toolu_013A79qt5toWSsKunFBDZd5S'}]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_with_tools.invoke(\n", + " \"what's bigger in 2024 LA or NYC\",\n", + " config={\"configurable\": {\"model\": \"claude-3-5-sonnet-20240620\"}},\n", + ").tool_calls" + ] } ], "metadata": { @@ -149,7 +326,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index a410672a07b..3b874ca4ff6 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1327,7 +1327,7 @@ class Runnable(Generic[Input, Output], ABC): def with_config( self, config: Optional[RunnableConfig] = None, - # Sadly Unpack is not well supported by mypy so this will have to be untyped + # Sadly Unpack is not well-supported by mypy so this will have to be untyped **kwargs: Any, ) -> Runnable[Input, Output]: """ diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index c1fb43ec96d..ca9a83d876d 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -1,30 +1,99 @@ +from __future__ import annotations + +import warnings from importlib import util -from typing import Any, Optional +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) from langchain_core._api import beta -from langchain_core.language_models.chat_models import ( +from langchain_core.language_models import ( BaseChatModel, + LanguageModelInput, SimpleChatModel, +) +from langchain_core.language_models.chat_models import ( agenerate_from_stream, generate_from_stream, ) +from langchain_core.messages import AnyMessage, BaseMessage +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables.schema import StreamEvent +from langchain_core.tools import BaseTool +from langchain_core.tracers import RunLog, RunLogPatch +from typing_extensions import TypeAlias __all__ = [ + "init_chat_model", + # For backwards compatibility "BaseChatModel", "SimpleChatModel", "generate_from_stream", "agenerate_from_stream", - "init_chat_model", ] +@overload +def init_chat_model( # type: ignore[overload-overlap] + model: str, + *, + model_provider: Optional[str] = None, + configurable_fields: Literal[None] = None, + config_prefix: Optional[str] = None, + **kwargs: Any, +) -> BaseChatModel: ... 
+
+
+@overload
+def init_chat_model(
+    model: Literal[None] = None,
+    *,
+    model_provider: Optional[str] = None,
+    configurable_fields: Literal[None] = None,
+    config_prefix: Optional[str] = None,
+    **kwargs: Any,
+) -> _ConfigurableModel: ...
+
+
+@overload
+def init_chat_model(
+    model: Optional[str] = None,
+    *,
+    model_provider: Optional[str] = None,
+    configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
+    config_prefix: Optional[str] = None,
+    **kwargs: Any,
+) -> _ConfigurableModel: ...
+
+
 # FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
 # name to the supported list in the docstring below. Do *not* change the order of the
 # existing providers.
 @beta()
 def init_chat_model(
-    model: str, *, model_provider: Optional[str] = None, **kwargs: Any
-) -> BaseChatModel:
+    model: Optional[str] = None,
+    *,
+    model_provider: Optional[str] = None,
+    configurable_fields: Optional[
+        Union[Literal["any"], List[str], Tuple[str, ...]]
+    ] = None,
+    config_prefix: Optional[str] = None,
+    **kwargs: Any,
+) -> Union[BaseChatModel, _ConfigurableModel]:
     """Initialize a ChatModel from the model name and provider.

     Must have the integration package corresponding to the model provider installed.
@@ -55,19 +124,43 @@ def init_chat_model(
             - gemini... -> google_vertexai
             - command... -> cohere
             - accounts/fireworks... -> fireworks
+        configurable_fields: Which model parameters are
+            configurable:
+            - None: No configurable fields.
+            - "any": All fields are configurable. *See Security Note below.*
+            - Union[List[str], Tuple[str, ...]]: Specified fields are configurable.
+
+            Fields are assumed to have the config_prefix stripped if a config_prefix
+            is set. Defaults to None if model is specified, and to
+            ``("model", "model_provider")`` if model is not specified.
+
+            ***Security Note***: Setting ``configurable_fields="any"`` means fields like
+            api_key, base_url, etc. can be altered at runtime, potentially redirecting
+            model requests to a different service/user. If you're accepting untrusted
+            configurations, make sure to enumerate ``configurable_fields=(...)``
+            explicitly.
+
+        config_prefix: If config_prefix is a non-empty string, then the model will be
+            configurable at runtime via the
+            ``config["configurable"]["{config_prefix}_{param}"]`` keys. If
+            config_prefix is an empty string, then the model will be configurable via
+            ``config["configurable"]["{param}"]``.
         kwargs: Additional keyword args to pass to
             ``<>.__init__(model=model_name, **kwargs)``.

     Returns:
-        The BaseChatModel corresponding to the model_name and model_provider specified.
+        A BaseChatModel corresponding to the model_name and model_provider specified
+        if no fields are configurable. Otherwise, a chat model emulator that
+        initializes the underlying model at runtime once a config is passed in.

     Raises:
         ValueError: If model_provider cannot be inferred or isn't supported.
         ImportError: If the model provider integration package is not installed.

-    Example:
+    Initialize non-configurable models:
         .. code-block:: python

+            # pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
             from langchain.chat_models import init_chat_model

             gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
@@ -77,7 +170,125 @@ def init_chat_model(
             gpt_4o.invoke("what's your name")
             claude_opus.invoke("what's your name")
             gemini_15.invoke("what's your name")
+
+
+    Create a partially configurable model with no default model:
+        .. code-block:: python
+
+            # pip install langchain langchain-openai langchain-anthropic
+            from langchain.chat_models import init_chat_model
+
+            # We don't need to specify configurable_fields if a model isn't specified.
+            configurable_model = init_chat_model(temperature=0)
+
+            configurable_model.invoke(
+                "what's your name",
+                config={"configurable": {"model": "gpt-4o"}}
+            )
+            # GPT-4o response
+
+            configurable_model.invoke(
+                "what's your name",
+                config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
+            )
+            # claude-3.5 sonnet response
+
+    Create a fully configurable model with a default model and a config prefix:
+        .. code-block:: python
+
+            # pip install langchain langchain-openai langchain-anthropic
+            from langchain.chat_models import init_chat_model
+
+            configurable_model_with_default = init_chat_model(
+                "gpt-4o",
+                model_provider="openai",
+                configurable_fields="any",  # this allows us to configure other params like temperature, max_tokens, etc. at runtime.
+                config_prefix="foo",
+                temperature=0
+            )
+
+            configurable_model_with_default.invoke("what's your name")
+            # GPT-4o response with temperature 0
+
+            configurable_model_with_default.invoke(
+                "what's your name",
+                config={
+                    "configurable": {
+                        "foo_model": "claude-3-5-sonnet-20240620",
+                        "foo_model_provider": "anthropic",
+                        "foo_temperature": 0.6
+                    }
+                }
+            )
+            # Claude-3.5 sonnet response with temperature 0.6
+
+    Bind tools to a configurable model:
+        You can call any ChatModel declarative methods on a configurable model in the
+        same way that you would with a normal model.
+
+        .. code-block:: python
+
+            # pip install langchain langchain-openai langchain-anthropic
+            from langchain.chat_models import init_chat_model
+            from langchain_core.pydantic_v1 import BaseModel, Field
+
+            class GetWeather(BaseModel):
+                '''Get the current weather in a given location'''
+
+                location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+            class GetPopulation(BaseModel):
+                '''Get the current population in a given location'''
+
+                location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
+
+            configurable_model = init_chat_model(
+                "gpt-4o",
+                configurable_fields=("model", "model_provider"),
+                temperature=0
+            )
+
+            configurable_model_with_tools = configurable_model.bind_tools([GetWeather, GetPopulation])
+            configurable_model_with_tools.invoke(
+                "Which city is hotter today and which is bigger: LA or NY?"
+            )
+            # GPT-4o response with tool calls
+
+            configurable_model_with_tools.invoke(
+                "Which city is hotter today and which is bigger: LA or NY?",
+                config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
+            )
+            # Claude-3.5 sonnet response with tools
     """  # noqa: E501
+    if not model and not configurable_fields:
+        configurable_fields = ("model", "model_provider")
+    config_prefix = config_prefix or ""
+    if config_prefix and not configurable_fields:
+        warnings.warn(
+            f"{config_prefix=} has been set but no fields are configurable. Set "
+            f"`configurable_fields=(...)` to specify the model params that are "
+            f"configurable."
+ ) + + if not configurable_fields: + return _init_chat_model_helper( + cast(str, model), model_provider=model_provider, **kwargs + ) + else: + if model: + kwargs["model"] = model + if model_provider: + kwargs["model_provider"] = model_provider + return _ConfigurableModel( + default_config=kwargs, + config_prefix=config_prefix, + configurable_fields=configurable_fields, + ) + + +def _init_chat_model_helper( + model: str, *, model_provider: Optional[str] = None, **kwargs: Any +) -> BaseChatModel: model_provider = model_provider or _attempt_infer_model_provider(model) if not model_provider: raise ValueError( @@ -200,3 +411,386 @@ def _check_pkg(pkg: str) -> None: f"Unable to import {pkg_kebab}. Please install with " f"`pip install -U {pkg_kebab}`" ) + + +def _remove_prefix(s: str, prefix: str) -> str: + if s.startswith(prefix): + s = s[len(prefix) :] + return s + + +_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output") + + +class _ConfigurableModel(Runnable[LanguageModelInput, Any]): + def __init__( + self, + *, + default_config: Optional[dict] = None, + configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any", + config_prefix: str = "", + queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (), + ) -> None: + self._default_config: dict = default_config or {} + self._configurable_fields: Union[Literal["any"], List[str]] = ( + configurable_fields + if configurable_fields == "any" + else list(configurable_fields) + ) + self._config_prefix = ( + config_prefix + "_" + if config_prefix and not config_prefix.endswith("_") + else config_prefix + ) + self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list( + queued_declarative_operations + ) + + def __getattr__(self, name: str) -> Any: + if name in _DECLARATIVE_METHODS: + # Declarative operations that cannot be applied until after an actual model + # object is instantiated. So instead of returning the actual operation, + # we record the operation and its arguments in a queue. This queue is + # then applied in order whenever we actually instantiate the model (in + # self._model()). + def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel: + queued_declarative_operations = list( + self._queued_declarative_operations + ) + queued_declarative_operations.append((name, args, kwargs)) + return _ConfigurableModel( + default_config=dict(self._default_config), + configurable_fields=list(self._configurable_fields) + if isinstance(self._configurable_fields, list) + else self._configurable_fields, + config_prefix=self._config_prefix, + queued_declarative_operations=queued_declarative_operations, + ) + + return queue + elif self._default_config and (model := self._model()) and hasattr(model, name): + return getattr(model, name) + else: + msg = f"{name} is not a BaseChatModel attribute" + if self._default_config: + msg += " and is not implemented on the default model" + msg += "." 
+ raise AttributeError(msg) + + def _model(self, config: Optional[RunnableConfig] = None) -> Runnable: + params = {**self._default_config, **self._model_params(config)} + model = _init_chat_model_helper(**params) + for name, args, kwargs in self._queued_declarative_operations: + model = getattr(model, name)(*args, **kwargs) + return model + + def _model_params(self, config: Optional[RunnableConfig]) -> dict: + config = config or {} + model_params = { + _remove_prefix(k, self._config_prefix): v + for k, v in config.get("configurable", {}).items() + if k.startswith(self._config_prefix) + } + if self._configurable_fields != "any": + model_params = { + k: v for k, v in model_params.items() if k in self._configurable_fields + } + return model_params + + def with_config( + self, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> _ConfigurableModel: + """Bind config to a Runnable, returning a new Runnable.""" + config = RunnableConfig(**(config or {}), **cast(RunnableConfig, kwargs)) + model_params = self._model_params(config) + remaining_config = {k: v for k, v in config.items() if k != "configurable"} + remaining_config["configurable"] = { + k: v + for k, v in config.get("configurable", {}).items() + if _remove_prefix(k, self._config_prefix) not in model_params + } + queued_declarative_operations = list(self._queued_declarative_operations) + if remaining_config: + queued_declarative_operations.append( + ("with_config", (), {"config": remaining_config}) + ) + return _ConfigurableModel( + default_config={**self._default_config, **model_params}, + configurable_fields=list(self._configurable_fields) + if isinstance(self._configurable_fields, list) + else self._configurable_fields, + config_prefix=self._config_prefix, + queued_declarative_operations=queued_declarative_operations, + ) + + @property + def InputType(self) -> TypeAlias: + """Get the input type for this runnable.""" + from langchain_core.prompt_values import ( + ChatPromptValueConcrete, + StringPromptValue, + ) + + # This is a version of LanguageModelInput which replaces the abstract + # base class BaseMessage with a union of its subclasses, which makes + # for a much better schema. + return Union[ + str, + Union[StringPromptValue, ChatPromptValueConcrete], + List[AnyMessage], + ] + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + return self._model(config).invoke(input, config=config, **kwargs) + + async def ainvoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Any: + return await self._model(config).ainvoke(input, config=config, **kwargs) + + def stream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Any]: + yield from self._model(config).stream(input, config=config, **kwargs) + + async def astream( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Any]: + async for x in self._model(config).astream(input, config=config, **kwargs): + yield x + + def batch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Any]: + config = config or None + # If <= 1 config use the underlying models batch implementation. 
+ if config is None or isinstance(config, dict) or len(config) <= 1: + if isinstance(config, list): + config = config[0] + return self._model(config).batch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + # If multiple configs default to Runnable.batch which uses executor to invoke + # in parallel. + else: + return super().batch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + + async def abatch( + self, + inputs: List[LanguageModelInput], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> List[Any]: + config = config or None + # If <= 1 config use the underlying models batch implementation. + if config is None or isinstance(config, dict) or len(config) <= 1: + if isinstance(config, list): + config = config[0] + return await self._model(config).abatch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + # If multiple configs default to Runnable.batch which uses executor to invoke + # in parallel. + else: + return await super().abatch( + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + + def batch_as_completed( + self, + inputs: Sequence[LanguageModelInput], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> Iterator[Tuple[int, Union[Any, Exception]]]: + config = config or None + # If <= 1 config use the underlying models batch implementation. + if config is None or isinstance(config, dict) or len(config) <= 1: + if isinstance(config, list): + config = config[0] + yield from self._model(cast(RunnableConfig, config)).batch_as_completed( # type: ignore[call-overload] + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + # If multiple configs default to Runnable.batch which uses executor to invoke + # in parallel. + else: + yield from super().batch_as_completed( # type: ignore[call-overload] + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ) + + async def abatch_as_completed( + self, + inputs: Sequence[LanguageModelInput], + config: Optional[Union[RunnableConfig, Sequence[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> AsyncIterator[Tuple[int, Any]]: + config = config or None + # If <= 1 config use the underlying models batch implementation. + if config is None or isinstance(config, dict) or len(config) <= 1: + if isinstance(config, list): + config = config[0] + async for x in self._model( + cast(RunnableConfig, config) + ).abatch_as_completed( # type: ignore[call-overload] + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ): + yield x + # If multiple configs default to Runnable.batch which uses executor to invoke + # in parallel. 
+ else: + async for x in super().abatch_as_completed( # type: ignore[call-overload] + inputs, config=config, return_exceptions=return_exceptions, **kwargs + ): + yield x + + def transform( + self, + input: Iterator[LanguageModelInput], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Any]: + for x in self._model(config).transform(input, config=config, **kwargs): + yield x + + async def atransform( + self, + input: AsyncIterator[LanguageModelInput], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Any]: + async for x in self._model(config).atransform(input, config=config, **kwargs): + yield x + + @overload + def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: Literal[True] = True, + with_streamed_output_list: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[RunLogPatch]: ... + + @overload + def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: Literal[False], + with_streamed_output_list: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[RunLog]: ... + + async def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + diff: bool = True, + with_streamed_output_list: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: + async for x in self._model(config).astream_log( # type: ignore[call-overload, misc] + input, + config=config, + diff=diff, + with_streamed_output_list=with_streamed_output_list, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_tags=exclude_tags, + exclude_types=exclude_types, + exclude_names=exclude_names, + **kwargs, + ): + yield x + + async def astream_events( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + version: Literal["v1", "v2"], + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> AsyncIterator[StreamEvent]: + async for x in self._model(config).astream_events( + input, + config=config, + version=version, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_tags=exclude_tags, + exclude_types=exclude_types, + exclude_names=exclude_names, + **kwargs, + ): + yield x + + # Explicitly added to satisfy downstream linters. 
+ def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + return self.__getattr__("bind_tools")(tools, **kwargs) + + # Explicitly added to satisfy downstream linters. + def with_structured_output( + self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + return self.__getattr__("with_structured_output")(schema, **kwargs) diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index a406f2ed012..36487ca1407 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1760,7 +1760,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.13" +version = "0.2.18" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -1784,7 +1784,7 @@ url = "../core" [[package]] name = "langchain-openai" -version = "0.1.15" +version = "0.1.16" description = "An integration package connecting OpenAI and LangChain" optional = true python-versions = ">=3.8.1,<4.0" @@ -1792,7 +1792,7 @@ files = [] develop = true [package.dependencies] -langchain-core = "^0.2.13" +langchain-core = "^0.2.17" openai = "^1.32.0" tiktoken = ">=0.7,<1" @@ -1800,6 +1800,24 @@ tiktoken = ">=0.7,<1" type = "directory" url = "../partners/openai" +[[package]] +name = "langchain-standard-tests" +version = "0.1.1" +description = "Standard tests for LangChain implementations" +optional = false +python-versions = ">=3.8.1,<4.0" +files = [] +develop = true + +[package.dependencies] +httpx = "^0.27.0" +langchain-core = ">=0.1.40,<0.3" +pytest = ">=7,<9" + +[package.source] +type = "directory" +url = "../standard-tests" + [[package]] name = "langchain-text-splitters" version = "0.2.2" @@ -2490,8 +2508,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4111,20 +4129,6 @@ files = [ cryptography = ">=35.0.0" types-pyOpenSSL = "*" -[[package]] -name = "types-requests" -version = "2.31.0.6" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.7" -files = [ - {file = "types-requests-2.31.0.6.tar.gz", hash = "sha256:cd74ce3b53c461f1228a9b783929ac73a666658f223e28ed29753771477b3bd0"}, - {file = "types_requests-2.31.0.6-py3-none-any.whl", hash = "sha256:a2db9cb228a81da8348b49ad6db3f5519452dd20a9c1e1a868c83c5fe88fd1a9"}, -] - -[package.dependencies] -types-urllib3 = "*" - [[package]] name = "types-requests" version = "2.32.0.20240622" @@ -4161,17 +4165,6 @@ files = [ {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"}, ] -[[package]] -name = "types-urllib3" -version = "1.26.25.14" -description = "Typing stubs for urllib3" -optional = false -python-versions = "*" -files = [ - {file = "types-urllib3-1.26.25.14.tar.gz", hash = "sha256:229b7f577c951b8c1b92c1bc2b2fdb0b49847bd2af6d1cc2a2e3dd340f3bda8f"}, - {file = "types_urllib3-1.26.25.14-py3-none-any.whl", hash = "sha256:9683bbb7fb72e32bfe9d2be6e04875fbe1b3eeec3cbb4ea231435aa7fd6b4f0e"}, -] - [[package]] name = "typing-extensions" version = "4.12.2" @@ 
-4208,22 +4201,6 @@ files = [ [package.extras] dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake8-commas", "flake8-comprehensions", "flake8-continuation", "flake8-datetimez", "flake8-docstrings", "flake8-import-order", "flake8-literal", "flake8-modern-annotations", "flake8-noqa", "flake8-pyproject", "flake8-requirements", "flake8-typechecking-import", "flake8-use-fstring", "mypy", "pep8-naming", "types-PyYAML"] -[[package]] -name = "urllib3" -version = "1.26.19" -description = "HTTP library with thread-safe connection pooling, file post, and more." -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" -files = [ - {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, - {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, -] - -[package.extras] -brotli = ["brotli (==1.0.9)", "brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] - [[package]] name = "urllib3" version = "2.2.2" @@ -4241,6 +4218,23 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "vcrpy" +version = "4.3.0" +description = "Automatically mock your HTTP interactions to simplify and speed up testing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "vcrpy-4.3.0-py2.py3-none-any.whl", hash = "sha256:8fbd4be412e8a7f35f623dd61034e6380a1c8dbd0edf6e87277a3289f6e98093"}, + {file = "vcrpy-4.3.0.tar.gz", hash = "sha256:49c270ce67e826dba027d83e20d25b67a5885487697e97bca6dbdf53d750a0ac"}, +] + +[package.dependencies] +PyYAML = "*" +six = ">=1.5" +wrapt = "*" +yarl = "*" + [[package]] name = "vcrpy" version = "6.0.1" @@ -4253,7 +4247,6 @@ files = [ [package.dependencies] PyYAML = "*" -urllib3 = {version = "<2", markers = "platform_python_implementation == \"PyPy\" or python_version < \"3.10\""} wrapt = "*" yarl = "*" @@ -4568,4 +4561,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "30237e9280ade99d7c7741aec1b3d38a8e1ccb24a3d0c4380d48ae80ab86a136" +content-hash = "14ebfabffa095e7619e9646bf56bc166d18c1c975b65e301bb6163c4e8eecaac" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index fb96d329390..730967f44d0 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -95,6 +95,10 @@ pytest-socket = "^0.6.0" syrupy = "^4.0.2" requests-mock = "^1.11.0" +[tool.poetry.group.test.dependencies.langchain-standard-tests] +path = "../standard-tests" +develop = true + [tool.poetry.group.codespell.dependencies] codespell = "^2.2.0" diff --git a/libs/langchain/tests/integration_tests/chat_models/__init__.py b/libs/langchain/tests/integration_tests/chat_models/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/tests/integration_tests/chat_models/test_base.py b/libs/langchain/tests/integration_tests/chat_models/test_base.py new file mode 100644 index 00000000000..cda11263ddf --- /dev/null +++ b/libs/langchain/tests/integration_tests/chat_models/test_base.py @@ -0,0 +1,59 @@ +from typing import Type, cast + +import pytest +from langchain_core.language_models import BaseChatModel +from 
langchain_core.messages import AIMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import RunnableConfig +from langchain_standard_tests.integration_tests import ChatModelIntegrationTests + +from langchain.chat_models import init_chat_model + + +class multiply(BaseModel): + """Product of two ints.""" + + x: int + y: int + + +@pytest.mark.requires("langchain_openai", "langchain_anthropic") +async def test_init_chat_model_chain() -> None: + model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar") + model_with_tools = model.bind_tools([multiply]) + + model_with_config = model_with_tools.with_config( + RunnableConfig(tags=["foo"]), + configurable={"bar_model": "claude-3-sonnet-20240229"}, + ) + prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")]) + chain = prompt | model_with_config + output = chain.invoke({"input": "bar"}) + assert isinstance(output, AIMessage) + events = [] + async for event in chain.astream_events({"input": "bar"}, version="v2"): + events.append(event) + assert events + + +class TestStandard(ChatModelIntegrationTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return cast(Type[BaseChatModel], init_chat_model) + + @property + def chat_model_params(self) -> dict: + return {"model": "gpt-4o", "configurable_fields": "any"} + + @property + def supports_image_inputs(self) -> bool: + return True + + @property + def has_tool_calling(self) -> bool: + return True + + @property + def has_structured_output(self) -> bool: + return True diff --git a/libs/langchain/tests/unit_tests/chat_models/test_base.py b/libs/langchain/tests/unit_tests/chat_models/test_base.py index 15aa119ff70..7e162ad5398 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_base.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_base.py @@ -1,4 +1,11 @@ +import os +from unittest import mock + import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import HumanMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import RunnableConfig, RunnableSequence from langchain.chat_models.base import __all__, init_chat_model @@ -34,14 +41,156 @@ def test_all_imports() -> None: ], ) def test_init_chat_model(model_name: str, model_provider: str) -> None: - init_chat_model(model_name, model_provider=model_provider, api_key="foo") + _: BaseChatModel = init_chat_model( + model_name, model_provider=model_provider, api_key="foo" + ) def test_init_missing_dep() -> None: with pytest.raises(ImportError): - init_chat_model("gpt-4o", model_provider="openai") + init_chat_model("mixtral-8x7b-32768", model_provider="groq") def test_init_unknown_provider() -> None: with pytest.raises(ValueError): init_chat_model("foo", model_provider="bar") + + +@pytest.mark.requires("langchain_openai") +@mock.patch.dict( + os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True +) +def test_configurable() -> None: + model = init_chat_model() + + for method in ( + "invoke", + "ainvoke", + "batch", + "abatch", + "stream", + "astream", + "batch_as_completed", + "abatch_as_completed", + ): + assert hasattr(model, method) + + # Doesn't have access non-configurable, non-declarative methods until a config is + # provided. 
+ for method in ("get_num_tokens", "get_num_tokens_from_messages"): + with pytest.raises(AttributeError): + getattr(model, method) + + # Can call declarative methods even without a default model. + model_with_tools = model.bind_tools( + [{"name": "foo", "description": "foo", "parameters": {}}] + ) + + # Check that original model wasn't mutated by declarative operation. + assert model._queued_declarative_operations == [] + + # Can iteratively call declarative methods. + model_with_config = model_with_tools.with_config( + RunnableConfig(tags=["foo"]), configurable={"model": "gpt-4o"} + ) + assert model_with_config.model_name == "gpt-4o" # type: ignore[attr-defined] + + for method in ("get_num_tokens", "get_num_tokens_from_messages"): + assert hasattr(model_with_config, method) + + assert model_with_config.dict() == { # type: ignore[attr-defined] + "name": None, + "bound": { + "model_name": "gpt-4o", + "model": "gpt-4o", + "stream": False, + "n": 1, + "temperature": 0.7, + "presence_penalty": None, + "frequency_penalty": None, + "seed": None, + "top_p": None, + "logprobs": False, + "top_logprobs": None, + "logit_bias": None, + "_type": "openai-chat", + }, + "kwargs": { + "tools": [ + { + "type": "function", + "function": {"name": "foo", "description": "foo", "parameters": {}}, + } + ] + }, + "config": {"tags": ["foo"], "configurable": {}}, + "config_factories": [], + "custom_input_type": None, + "custom_output_type": None, + } + + +@pytest.mark.requires("langchain_openai", "langchain_anthropic") +@mock.patch.dict( + os.environ, {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "foo"}, clear=True +) +def test_configurable_with_default() -> None: + model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar") + for method in ( + "invoke", + "ainvoke", + "batch", + "abatch", + "stream", + "astream", + "batch_as_completed", + "abatch_as_completed", + ): + assert hasattr(model, method) + + # Does have access non-configurable, non-declarative methods since default params + # are provided. + for method in ("get_num_tokens", "get_num_tokens_from_messages", "dict"): + assert hasattr(model, method) + + assert model.model_name == "gpt-4o" # type: ignore[attr-defined] + + model_with_tools = model.bind_tools( + [{"name": "foo", "description": "foo", "parameters": {}}] + ) + + model_with_config = model_with_tools.with_config( + RunnableConfig(tags=["foo"]), + configurable={"bar_model": "claude-3-sonnet-20240229"}, + ) + + assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined] + # Anthropic defaults to using `transformers` for token counting. 
+ with pytest.raises(ImportError): + model_with_config.get_num_tokens_from_messages([(HumanMessage("foo"))]) # type: ignore[attr-defined] + + assert model_with_config.dict() == { # type: ignore[attr-defined] + "name": None, + "bound": { + "model": "claude-3-sonnet-20240229", + "max_tokens": 1024, + "temperature": None, + "top_k": None, + "top_p": None, + "model_kwargs": {}, + "streaming": False, + "max_retries": 2, + "default_request_timeout": None, + "_type": "anthropic-chat", + }, + "kwargs": { + "tools": [{"name": "foo", "description": "foo", "input_schema": {}}] + }, + "config": {"tags": ["foo"], "configurable": {}}, + "config_factories": [], + "custom_input_type": None, + "custom_output_type": None, + } + prompt = ChatPromptTemplate.from_messages([("system", "foo")]) + chain = prompt | model_with_config + assert isinstance(chain, RunnableSequence) diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index b182bade206..df04f6bd8ac 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -79,6 +79,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: "duckdb-engine", "freezegun", "langchain-core", + "langchain-standard-tests", "langchain-text-splitters", "langchain-openai", "lark",