From 1acf57e9bd2dcb1ad6317d24625e5828f65b3812 Mon Sep 17 00:00:00 2001
From: Bagatur <22008038+baskaryan@users.noreply.github.com>
Date: Mon, 17 Feb 2025 09:45:39 -0800
Subject: [PATCH] langchain[patch]: init_chat_model xai support (#29849)

---
 libs/langchain/langchain/chat_models/base.py | 51 +++++++++++++-------
 1 file changed, 34 insertions(+), 17 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 1b73f153d82..6a56418f13e 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -98,9 +98,12 @@ def init_chat_model(
     installed.
 
     Args:
-        model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
-        model_provider: The model provider. Supported model_provider values and the
-            corresponding integration package are:
+        model: The name of the model, e.g. "o3-mini", "claude-3-5-sonnet-latest". You can
+            also specify model and model provider in a single argument using
+            '{model_provider}:{model}' format, e.g. "openai:o1".
+        model_provider: The model provider if not specified as part of model arg (see
+            above). Supported model_provider values and the corresponding integration
+            package are:
 
             - 'openai' -> langchain-openai
             - 'anthropic' -> langchain-anthropic
@@ -118,8 +121,9 @@ def init_chat_model(
             - 'ollama' -> langchain-ollama
             - 'google_anthropic_vertex' -> langchain-google-vertexai
             - 'deepseek' -> langchain-deepseek
-            - 'ibm' -> langchain-ibm
+            - 'ibm' -> langchain-ibm
             - 'nvidia' -> langchain-nvidia-ai-endpoints
+            - 'xai' -> langchain-xai
 
         Will attempt to infer model_provider from model if not specified. The
         following providers will be inferred based on these model prefixes:
@@ -131,6 +135,8 @@ def init_chat_model(
             - 'command...' -> 'cohere'
             - 'accounts/fireworks...' -> 'fireworks'
             - 'mistral...' -> 'mistralai'
+            - 'deepseek...' -> 'deepseek'
+            - 'grok...' -> 'xai'
 
         configurable_fields: Which model parameters are
            configurable:
@@ -182,13 +188,13 @@ def init_chat_model(
             # pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
             from langchain.chat_models import init_chat_model
 
-            gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
-            claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
-            gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)
+            o3_mini = init_chat_model("openai:o3-mini", temperature=0)
+            claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
+            gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
 
-            gpt_4o.invoke("what's your name")
-            claude_opus.invoke("what's your name")
-            gemini_15.invoke("what's your name")
+            o3_mini.invoke("what's your name")
+            claude_sonnet.invoke("what's your name")
+            gemini_2_flash.invoke("what's your name")
 
     .. dropdown:: Partially configurable model with no default
 
@@ -209,7 +215,7 @@ def init_chat_model(
 
             configurable_model.invoke(
                 "what's your name",
-                config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
+                config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
             )
             # claude-3.5 sonnet response
 
@@ -221,8 +227,7 @@ def init_chat_model(
             from langchain.chat_models import init_chat_model
 
             configurable_model_with_default = init_chat_model(
-                "gpt-4o",
-                model_provider="openai",
+                "openai:gpt-4o",
                 configurable_fields="any",  # this allows us to configure other params like temperature, max_tokens, etc at runtime.
                 config_prefix="foo",
                 temperature=0
@@ -235,8 +240,7 @@ def init_chat_model(
                 "what's your name",
                 config={
                     "configurable": {
-                        "foo_model": "claude-3-5-sonnet-20240620",
-                        "foo_model_provider": "anthropic",
+                        "foo_model": "anthropic:claude-3-5-sonnet-20240620",
                         "foo_temperature": 0.6
                     }
                 }
@@ -290,18 +294,22 @@ def init_chat_model(
 
     .. versionchanged:: 0.2.12
 
-        Support for ChatOllama via langchain-ollama package added
+        Support for Ollama via langchain-ollama package added
         (langchain_ollama.ChatOllama). Previously,
         the now-deprecated langchain-community version of Ollama was imported
         (langchain_community.chat_models.ChatOllama).
 
-        Support for langchain_aws.ChatBedrockConverse added
+        Support for AWS Bedrock models via the Converse API added
         (model_provider="bedrock_converse").
 
     .. versionchanged:: 0.3.5
 
        Out of beta.
 
+    .. versionchanged:: 0.3.19
+
+        Support for Deepseek, IBM, Nvidia, and xAI models added.
+
     """  # noqa: E501
     if not model and not configurable_fields:
         configurable_fields = ("model", "model_provider")
@@ -434,6 +442,11 @@ def _init_chat_model_helper(
         from langchain_ibm import ChatWatsonx
 
         return ChatWatsonx(model_id=model, **kwargs)
+    elif model_provider == "xai":
+        _check_pkg("langchain_xai")
+        from langchain_xai import ChatXAI
+
+        return ChatXAI(model=model, **kwargs)
     else:
         supported = ", ".join(_SUPPORTED_PROVIDERS)
         raise ValueError(
@@ -478,6 +491,10 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
         return "bedrock"
     elif model_name.startswith("mistral"):
         return "mistralai"
+    elif model_name.startswith("deepseek"):
+        return "deepseek"
+    elif model_name.startswith("grok"):
+        return "xai"
     else:
         return None
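A minimal usage sketch of the new xAI path, following the same call pattern as the docstring
examples above (the model name and the XAI_API_KEY environment variable are illustrative
assumptions, not part of this patch):

    # pip install langchain langchain-xai   # assumes XAI_API_KEY is set
    from langchain.chat_models import init_chat_model

    # Explicit "xai:" prefix; a bare "grok..." model name would also infer provider="xai".
    grok = init_chat_model("xai:grok-2", temperature=0)  # model name is illustrative
    grok.invoke("what's your name")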