langchain[patch]: init_chat_model xai support (#29849)
Parent: 1a55da9ff4
Commit: 1acf57e9bd
@@ -98,9 +98,12 @@ def init_chat_model(
         installed.

     Args:
-        model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
-        model_provider: The model provider. Supported model_provider values and the
-            corresponding integration package are:
+        model: The name of the model, e.g. "o3-mini", "claude-3-5-sonnet-latest". You can
+            also specify model and model provider in a single argument using
+            '{model_provider}:{model}' format, e.g. "openai:o1".
+        model_provider: The model provider if not specified as part of model arg (see
+            above). Supported model_provider values and the corresponding integration
+            package are:

         - 'openai'              -> langchain-openai
         - 'anthropic'           -> langchain-anthropic
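The combined '{model_provider}:{model}' format documented above can be exercised directly. A minimal sketch, assuming langchain and langchain-openai are installed and OPENAI_API_KEY is set ("gpt-4o-mini" is an illustrative model name):

    from langchain.chat_models import init_chat_model

    # Provider and model in a single argument, using the new
    # '{model_provider}:{model}' format.
    llm = init_chat_model("openai:gpt-4o-mini", temperature=0)
    print(llm.invoke("what's your name").content)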
@@ -118,8 +121,9 @@ def init_chat_model(
         - 'ollama'              -> langchain-ollama
         - 'google_anthropic_vertex' -> langchain-google-vertexai
         - 'deepseek'            -> langchain-deepseek
-        - 'ibm'                 -> langchain-ibm
+        - 'ibm'                 -> langchain-ibm
         - 'nvidia'              -> langchain-nvidia-ai-endpoints
+        - 'xai'                 -> langchain-xai

     Will attempt to infer model_provider from model if not specified. The
     following providers will be inferred based on these model prefixes:
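Each provider value requires its matching integration package; for the new 'xai' entry that is langchain-xai. A sketch, assuming XAI_API_KEY is set and "grok-2" names an available model:

    # pip install langchain langchain-xai
    from langchain.chat_models import init_chat_model

    grok = init_chat_model("grok-2", model_provider="xai", temperature=0)
    print(grok.invoke("what's your name").content)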
@@ -131,6 +135,8 @@ def init_chat_model(
         - 'command...'            -> 'cohere'
         - 'accounts/fireworks...' -> 'fireworks'
         - 'mistral...'            -> 'mistralai'
+        - 'deepseek...'           -> 'deepseek'
+        - 'grok...'               -> 'xai'

     configurable_fields: Which model parameters are
         configurable:
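With the two new prefixes, model_provider can be omitted entirely for matching names. A sketch (model names are illustrative; the relevant integration packages and API keys are assumed):

    from langchain.chat_models import init_chat_model

    # "deepseek..." now infers model_provider="deepseek",
    # "grok..." now infers model_provider="xai".
    deepseek = init_chat_model("deepseek-chat", temperature=0)
    grok = init_chat_model("grok-2", temperature=0)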
@@ -182,13 +188,13 @@ def init_chat_model(
             # pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
             from langchain.chat_models import init_chat_model

-            gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
-            claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
-            gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)
+            o3_mini = init_chat_model("openai:o3-mini", temperature=0)
+            claude_sonnet = init_chat_model("anthropic:claude-3-5-sonnet-latest", temperature=0)
+            gemini_2_flash = init_chat_model("google_vertexai:gemini-2.0-flash", temperature=0)

-            gpt_4o.invoke("what's your name")
-            claude_opus.invoke("what's your name")
-            gemini_15.invoke("what's your name")
+            o3_mini.invoke("what's your name")
+            claude_sonnet.invoke("what's your name")
+            gemini_2_flash.invoke("what's your name")


     .. dropdown:: Partially configurable model with no default
@@ -209,7 +215,7 @@ def init_chat_model(

             configurable_model.invoke(
                 "what's your name",
-                config={"configurable": {"model": "claude-3-5-sonnet-20240620"}}
+                config={"configurable": {"model": "claude-3-5-sonnet-latest"}}
             )
             # claude-3.5 sonnet response

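Rather than passing config on every call, the standard Runnable with_config method can pin the choice once; a brief sketch using the configurable_model from the dropdown above:

    # Bind the configurable "model" key once, then invoke normally.
    claude = configurable_model.with_config(
        configurable={"model": "claude-3-5-sonnet-latest"}
    )
    claude.invoke("what's your name")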
@@ -221,8 +227,7 @@ def init_chat_model(
             from langchain.chat_models import init_chat_model

             configurable_model_with_default = init_chat_model(
-                "gpt-4o",
-                model_provider="openai",
+                "openai:gpt-4o",
                 configurable_fields="any",  # this allows us to configure other params like temperature, max_tokens, etc at runtime.
                 config_prefix="foo",
                 temperature=0
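Since a default ("openai:gpt-4o", temperature 0) is baked in, the configurable model can also be invoked with no config at all; a sketch consistent with the surrounding docstring example:

    configurable_model_with_default.invoke("what's your name")
    # GPT-4o response with temperature 0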
@@ -235,8 +240,7 @@ def init_chat_model(
                 "what's your name",
                 config={
                     "configurable": {
-                        "foo_model": "claude-3-5-sonnet-20240620",
-                        "foo_model_provider": "anthropic",
+                        "foo_model": "anthropic:claude-3-5-sonnet-20240620",
                         "foo_temperature": 0.6
                     }
                 }
@@ -290,18 +294,22 @@ def init_chat_model(

     .. versionchanged:: 0.2.12

-        Support for ChatOllama via langchain-ollama package added
+        Support for Ollama via langchain-ollama package added
         (langchain_ollama.ChatOllama). Previously,
         the now-deprecated langchain-community version of Ollama was imported
         (langchain_community.chat_models.ChatOllama).

-        Support for langchain_aws.ChatBedrockConverse added
+        Support for AWS Bedrock models via the Converse API added
         (model_provider="bedrock_converse").

     .. versionchanged:: 0.3.5

         Out of beta.

+    .. versionchanged:: 0.3.19
+
+        Support for Deepseek, IBM, Nvidia, and xAI models added.
+
     """  # noqa: E501
     if not model and not configurable_fields:
         configurable_fields = ("model", "model_provider")
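The Bedrock Converse support noted under 0.2.12 routes through langchain-aws. A sketch, assuming AWS credentials are configured and this Anthropic model id is enabled in your region (the id is illustrative):

    # pip install langchain-aws
    from langchain.chat_models import init_chat_model

    claude_bedrock = init_chat_model(
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
        model_provider="bedrock_converse",
    )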
@@ -434,6 +442,11 @@ def _init_chat_model_helper(
         from langchain_ibm import ChatWatsonx

         return ChatWatsonx(model_id=model, **kwargs)
+    elif model_provider == "xai":
+        _check_pkg("langchain_xai")
+        from langchain_xai import ChatXAI
+
+        return ChatXAI(model=model, **kwargs)
     else:
         supported = ", ".join(_SUPPORTED_PROVIDERS)
         raise ValueError(
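With this branch in place, init_chat_model delegates to ChatXAI, so the two constructions below should yield equivalent clients; a sketch, assuming langchain-xai is installed:

    from langchain.chat_models import init_chat_model
    from langchain_xai import ChatXAI

    via_helper = init_chat_model("xai:grok-2", temperature=0)
    direct = ChatXAI(model="grok-2", temperature=0)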
@@ -478,6 +491,10 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
         return "bedrock"
     elif model_name.startswith("mistral"):
         return "mistralai"
+    elif model_name.startswith("deepseek"):
+        return "deepseek"
+    elif model_name.startswith("grok"):
+        return "xai"
     else:
         return None
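The growing chain of startswith checks could equally be written table-driven, keeping each future provider a one-line addition. A sketch covering only the prefixes visible in this diff (not the shipped implementation):

    from typing import Optional

    _PREFIX_TO_PROVIDER = {
        "command": "cohere",
        "accounts/fireworks": "fireworks",
        "mistral": "mistralai",
        "deepseek": "deepseek",
        "grok": "xai",
    }

    def infer_model_provider(model_name: str) -> Optional[str]:
        # First matching prefix wins, mirroring the elif chain above.
        for prefix, provider in _PREFIX_TO_PROVIDER.items():
            if model_name.startswith(prefix):
                return provider
        return None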