langchain[patch]: init_chat_model xai support (#29849)

This commit is contained in:
Bagatur 2025-02-17 09:45:39 -08:00 committed by GitHub
parent 1a55da9ff4
commit 1acf57e9bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -98,9 +98,12 @@ def init_chat_model(
installed.
Args:
model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
model_provider: The model provider. Supported model_provider values and the
corresponding integration package are:
model: The name of the model, e.g. "o3-mini", "claude-3-5-sonnet-latest". You can
also specify model and model provider in a single argument using
'{model_provider}:{model}' format, e.g. "openai:o1".
model_provider: The model provider if not specified as part of model arg (see
above). Supported model_provider values and the corresponding integration
package are:
- 'openai' -> langchain-openai
- 'anthropic' -> langchain-anthropic
@ -118,8 +121,9 @@ def init_chat_model(
- 'ollama' -> langchain-ollama
- 'google_anthropic_vertex' -> langchain-google-vertexai
- 'deepseek' -> langchain-deepseek
- 'ibm' -> langchain-ibm
- 'ibm' -> langchain-ibm
- 'nvidia' -> langchain-nvidia-ai-endpoints
- 'xai' -> langchain-xai
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
@ -131,6 +135,8 @@ def init_chat_model(
- 'command...' -> 'cohere'
- 'accounts/fireworks...' -> 'fireworks'
- 'mistral...' -> 'mistralai'
- 'deepseek...' -> 'deepseek'
- 'grok...' -> 'xai'
configurable_fields: Which model parameters are
configurable:
@ -182,13 +188,13 @@ def init_chat_model(
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain.chat_models import init_chat_model
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)
gpt_4o.invoke("what's your name")
claude_opus.invoke("what's your name")
gemini_15.invoke("what's your name")
o3_mini.invoke("what's your name")
claude_sonnet.invoke("what's your name")
gemini_2_flash.invoke("what's your name")
.. dropdown:: Partially configurable model with no default
@ -209,7 +215,7 @@ def init_chat_model(
configurable_model.invoke(
"what's your name",
config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
)
# claude-3.5 sonnet response
@ -221,8 +227,7 @@ def init_chat_model(
from langchain.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"gpt-4o",
model_provider="openai",
"openai:gpt-4o",
configurable_fields="any", # this allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0
@ -235,8 +240,7 @@ def init_chat_model(
"what's your name",
config={
"configurable": {
"foo_model": "claude-3-5-sonnet-20240620",
"foo_model_provider": "anthropic",
"foo_model": "anthropic:claude-3-5-sonnet-20240620",
"foo_temperature": 0.6
}
}
@ -290,18 +294,22 @@ def init_chat_model(
.. versionchanged:: 0.2.12
Support for ChatOllama via langchain-ollama package added
Support for Ollama via langchain-ollama package added
(langchain_ollama.ChatOllama). Previously,
the now-deprecated langchain-community version of Ollama was imported
(langchain_community.chat_models.ChatOllama).
Support for langchain_aws.ChatBedrockConverse added
Support for AWS Bedrock models via the Converse API added
(model_provider="bedrock_converse").
.. versionchanged:: 0.3.5
Out of beta.
.. versionchanged:: 0.3.19
Support for Deepseek, IBM, Nvidia, and xAI models added.
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
@ -434,6 +442,11 @@ def _init_chat_model_helper(
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
elif model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
else:
supported = ", ".join(_SUPPORTED_PROVIDERS)
raise ValueError(
@ -478,6 +491,10 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return "bedrock"
elif model_name.startswith("mistral"):
return "mistralai"
elif model_name.startswith("deepseek"):
return "deepseek"
elif model_name.startswith("grok"):
return "xai"
else:
return None