mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
langchain[minor]: add universal init_model (#22039)
Decisions to discuss: only chat models; model_provider isn't based on any existing values like llm-type, package names, or class names; implemented as a function, not as a wrapper ChatModel; function name (init_model); placed in langchain as opposed to community or core; marked beta.
This commit is contained in:
@@ -21,6 +21,7 @@ import warnings
|
||||
from langchain_core._api import LangChainDeprecationWarning
|
||||
|
||||
from langchain._api.interactive_env import is_interactive_env
|
||||
from langchain.chat_models.base import init_chat_model
|
||||
|
||||
|
||||
def __getattr__(name: str) -> None:
|
||||
@@ -41,6 +42,7 @@ def __getattr__(name: str) -> None:
|
||||
|
||||
|
||||
__all__ = [
|
||||
"init_chat_model",
|
||||
"ChatOpenAI",
|
||||
"BedrockChat",
|
||||
"AzureChatOpenAI",
|
||||
|
@@ -1,3 +1,7 @@
|
||||
from importlib import util
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
SimpleChatModel,
|
||||
@@ -10,4 +14,189 @@ __all__ = [
|
||||
"SimpleChatModel",
|
||||
"generate_from_stream",
|
||||
"agenerate_from_stream",
|
||||
"init_chat_model",
|
||||
]
|
||||
|
||||
|
||||
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
@beta()
def init_chat_model(
    model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
    """Initialize a ChatModel from the model name and provider.

    Must have the integration package corresponding to the model provider installed.

    Args:
        model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
        model_provider: The model provider. Supported model_provider values and the
            corresponding integration package:
                - openai (langchain-openai)
                - anthropic (langchain-anthropic)
                - azure_openai (langchain-openai)
                - google_vertexai (langchain-google-vertexai)
                - google_genai (langchain-google-genai)
                - bedrock (langchain-aws)
                - cohere (langchain-cohere)
                - fireworks (langchain-fireworks)
                - together (langchain-together)
                - mistralai (langchain-mistralai)
                - huggingface (langchain-huggingface)
                - groq (langchain-groq)
                - ollama (langchain-community)

            Will attempt to infer model_provider from model if not specified. The
            following providers will be inferred based on these model prefixes:
                - gpt-3... or gpt-4... -> openai
                - claude... -> anthropic
                - amazon.... -> bedrock
                - gemini... -> google_vertexai
                - command... -> cohere
                - accounts/fireworks... -> fireworks
        kwargs: Additional keyword args to pass to
            ``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.

    Returns:
        The BaseChatModel corresponding to the model_name and model_provider specified.

    Raises:
        ValueError: If model_provider cannot be inferred or isn't supported.
        ImportError: If the model provider integration package is not installed.

    Example:
        .. code-block:: python

            from langchain.chat_models import init_chat_model

            gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
            claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
            gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)

            gpt_4o.invoke("what's your name")
            claude_opus.invoke("what's your name")
            gemini_15.invoke("what's your name")
    """  # noqa: E501
    # Local import keeps module import time unchanged; only needed on call.
    from importlib import import_module

    model_provider = model_provider or _attempt_infer_model_provider(model)
    if not model_provider:
        raise ValueError(
            f"Unable to infer model provider for {model=}, please specify "
            f"model_provider directly."
        )
    model_provider = model_provider.replace("-", "_").lower()
    # Single registry mapping each supported provider to:
    #   (package to verify installed, module to import from, chat class name,
    #    keyword used to pass the model name).
    # Keeping the dispatch in one table (instead of a long if/elif chain)
    # prevents it from drifting out of sync with _SUPPORTED_PROVIDERS.
    registry: dict = {
        "openai": ("langchain_openai", "langchain_openai", "ChatOpenAI", "model"),
        "anthropic": (
            "langchain_anthropic",
            "langchain_anthropic",
            "ChatAnthropic",
            "model",
        ),
        "azure_openai": (
            "langchain_openai",
            "langchain_openai",
            "AzureChatOpenAI",
            "model",
        ),
        "cohere": ("langchain_cohere", "langchain_cohere", "ChatCohere", "model"),
        "google_vertexai": (
            "langchain_google_vertexai",
            "langchain_google_vertexai",
            "ChatVertexAI",
            "model",
        ),
        "google_genai": (
            "langchain_google_genai",
            "langchain_google_genai",
            "ChatGoogleGenerativeAI",
            "model",
        ),
        "fireworks": (
            "langchain_fireworks",
            "langchain_fireworks",
            "ChatFireworks",
            "model",
        ),
        # ChatOllama still lives in the community package's chat_models module.
        "ollama": (
            "langchain_community",
            "langchain_community.chat_models",
            "ChatOllama",
            "model",
        ),
        "together": ("langchain_together", "langchain_together", "ChatTogether", "model"),
        "mistralai": (
            "langchain_mistralai",
            "langchain_mistralai",
            "ChatMistralAI",
            "model",
        ),
        # ChatHuggingFace takes the model name via ``model_id``.
        "huggingface": (
            "langchain_huggingface",
            "langchain_huggingface",
            "ChatHuggingFace",
            "model_id",
        ),
        "groq": ("langchain_groq", "langchain_groq", "ChatGroq", "model"),
        # TODO: update to use model= once ChatBedrock supports it.
        "bedrock": ("langchain_aws", "langchain_aws", "ChatBedrock", "model_id"),
    }
    if model_provider not in registry:
        supported = ", ".join(_SUPPORTED_PROVIDERS)
        raise ValueError(
            f"Unsupported {model_provider=}.\n\nSupported model providers are: "
            f"{supported}"
        )
    pkg, module, class_name, model_kwarg = registry[model_provider]
    _check_pkg(pkg)
    chat_cls = getattr(import_module(module), class_name)
    return chat_cls(**{model_kwarg: model}, **kwargs)
|
||||
|
||||
|
||||
_SUPPORTED_PROVIDERS = {
|
||||
"openai",
|
||||
"anthropic",
|
||||
"azure_openai",
|
||||
"cohere",
|
||||
"google_vertexai",
|
||||
"google_genai",
|
||||
"fireworks",
|
||||
"ollama",
|
||||
"together",
|
||||
"mistralai",
|
||||
"huggingface",
|
||||
"groq",
|
||||
"bedrock",
|
||||
}
|
||||
|
||||
|
||||
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
|
||||
if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"):
|
||||
return "openai"
|
||||
elif model_name.startswith("claude"):
|
||||
return "anthropic"
|
||||
elif model_name.startswith("command"):
|
||||
return "cohere"
|
||||
elif model_name.startswith("accounts/fireworks"):
|
||||
return "fireworks"
|
||||
elif model_name.startswith("gemini"):
|
||||
return "google_vertexai"
|
||||
elif model_name.startswith("amazon."):
|
||||
return "bedrock"
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _check_pkg(pkg: str) -> None:
|
||||
if not util.find_spec(pkg):
|
||||
pkg_kebab = pkg.replace("_", "-")
|
||||
raise ImportError(
|
||||
f"Unable to import {pkg_kebab}. Please install with "
|
||||
f"`pip install -U {pkg_kebab}`"
|
||||
)
|
||||
|
3082
libs/langchain/poetry.lock
generated
3082
libs/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -22,15 +22,11 @@ PyYAML = ">=5.3"
|
||||
numpy = "^1"
|
||||
aiohttp = "^3.8.3"
|
||||
tenacity = "^8.1.0"
|
||||
async-timeout = {version = "^4.0.0", python = "<3.11"}
|
||||
azure-core = {version = "^1.26.4", optional=true}
|
||||
tqdm = {version = ">=4.48.0", optional = true}
|
||||
openapi-pydantic = {version = "^0.3.2", optional = true}
|
||||
faiss-cpu = {version = "^1", optional = true}
|
||||
manifest-ml = {version = "^0.0.1", optional = true}
|
||||
transformers = {version = "^4", optional = true}
|
||||
beautifulsoup4 = {version = "^4", optional = true}
|
||||
torch = {version = ">=1,<3", optional = true}
|
||||
jinja2 = {version = "^3", optional = true}
|
||||
tiktoken = {version = ">=0.7,<1.0", optional = true, python=">=3.9"}
|
||||
qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
|
||||
cohere = {version = ">=4,<6", optional = true}
|
||||
@@ -38,77 +34,28 @@ openai = {version = "<2", optional = true}
|
||||
nlpcloud = {version = "^1", optional = true}
|
||||
huggingface_hub = {version = "^0", optional = true}
|
||||
sentence-transformers = {version = "^2", optional = true}
|
||||
arxiv = {version = "^1.4", optional = true}
|
||||
pypdf = {version = "^3.4.0", optional = true}
|
||||
aleph-alpha-client = {version="^2.15.0", optional = true}
|
||||
pgvector = {version = "^0.1.6", optional = true}
|
||||
async-timeout = {version = "^4.0.0", python = "<3.11"}
|
||||
azure-identity = {version = "^1.12.0", optional=true}
|
||||
atlassian-python-api = {version = "^3.36.0", optional=true}
|
||||
html2text = {version="^2020.1.16", optional=true}
|
||||
numexpr = {version="^2.8.6", optional=true}
|
||||
azure-cosmos = {version="^4.4.0b1", optional=true}
|
||||
jq = {version = "^1.4.1", optional = true}
|
||||
pdfminer-six = {version = "^20221105", optional = true}
|
||||
docarray = {version="^0.32.0", extras=["hnswlib"], optional=true}
|
||||
lxml = {version = ">=4.9.3,<6.0", optional = true}
|
||||
pymupdf = {version = "^1.22.3", optional = true}
|
||||
rapidocr-onnxruntime = {version = "^1.3.2", optional = true, python = ">=3.8.1,<3.12"}
|
||||
pypdfium2 = {version = "^4.10.0", optional = true}
|
||||
gql = {version = "^3.4.1", optional = true}
|
||||
pandas = {version = "^2.0.1", optional = true}
|
||||
telethon = {version = "^1.28.5", optional = true}
|
||||
chardet = {version="^5.1.0", optional=true}
|
||||
requests-toolbelt = {version = "^1.0.0", optional = true}
|
||||
openlm = {version = "^0.0.5", optional = true}
|
||||
scikit-learn = {version = "^1.2.2", optional = true}
|
||||
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
|
||||
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
|
||||
py-trello = {version = "^0.19.0", optional = true}
|
||||
bibtexparser = {version = "^1.4.0", optional = true}
|
||||
pyspark = {version = "^3.4.0", optional = true}
|
||||
clarifai = {version = ">=9.1.0", optional = true}
|
||||
mwparserfromhell = {version = "^0.6.4", optional = true}
|
||||
mwxml = {version = "^0.3.3", optional = true}
|
||||
azure-search-documents = {version = "11.4.0b8", optional = true}
|
||||
esprima = {version = "^4.0.1", optional = true}
|
||||
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
|
||||
psychicapi = {version = "^0.8.0", optional = true}
|
||||
cassio = {version = "^0.1.0", optional = true}
|
||||
sympy = {version = "^1.12", optional = true}
|
||||
rapidfuzz = {version = "^3.1.1", optional = true}
|
||||
jsonschema = {version = ">1", optional = true}
|
||||
rank-bm25 = {version = "^0.2.2", optional = true}
|
||||
geopandas = {version = "^0.13.1", optional = true}
|
||||
gitpython = {version = "^3.1.32", optional = true}
|
||||
feedparser = {version = "^6.0.10", optional = true}
|
||||
newspaper3k = {version = "^0.2.8", optional = true}
|
||||
xata = {version = "^1.0.0a7", optional = true}
|
||||
xmltodict = {version = "^0.13.0", optional = true}
|
||||
markdownify = {version = "^0.11.6", optional = true}
|
||||
assemblyai = {version = "^0.17.0", optional = true}
|
||||
dashvector = {version = "^1.0.1", optional = true}
|
||||
sqlite-vss = {version = "^0.1.2", optional = true}
|
||||
motor = {version = "^3.3.1", optional = true}
|
||||
timescale-vector = {version = "^0.0.1", optional = true}
|
||||
typer = {version= "^0.9.0", optional = true}
|
||||
anthropic = {version = "^0.3.11", optional = true}
|
||||
aiosqlite = {version = "^0.19.0", optional = true}
|
||||
rspace_client = {version = "^2.5.0", optional = true}
|
||||
upstash-redis = {version = "^0.15.0", optional = true}
|
||||
azure-ai-textanalytics = {version = "^5.3.0", optional = true}
|
||||
google-cloud-documentai = {version = "^2.20.1", optional = true}
|
||||
fireworks-ai = {version = "^0.9.0", optional = true}
|
||||
javelin-sdk = {version = "^0.1.8", optional = true}
|
||||
hologres-vector = {version = "^0.0.6", optional = true}
|
||||
praw = {version = "^7.7.1", optional = true}
|
||||
msal = {version = "^1.25.0", optional = true}
|
||||
databricks-vectorsearch = {version = "^0.21", optional = true}
|
||||
couchbase = {version = "^4.1.9", optional = true}
|
||||
dgml-utils = {version = "^0.3.0", optional = true}
|
||||
datasets = {version = "^2.15.0", optional = true}
|
||||
langchain-openai = {version = "^0.1", optional = true}
|
||||
rdflib = {version = "7.0.0", optional = true}
|
||||
langchain-openai = {version = "^0", optional = true}
|
||||
langchain-anthropic = {version = "^0", optional = true}
|
||||
langchain-fireworks = {version = "^0", optional = true}
|
||||
langchain-together = {version = "^0", optional = true}
|
||||
langchain-mistralai = {version = "^0", optional = true}
|
||||
langchain-groq = {version = "^0", optional = true}
|
||||
jsonschema = {version = "^4.22.0", optional = true}
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
@@ -162,11 +109,8 @@ optional = true
|
||||
# https://python.langchain.com/docs/contributing/code#working-with-optional-dependencies
|
||||
pytest-vcr = "^1.0.2"
|
||||
wrapt = "^1.15.0"
|
||||
openai = "^1"
|
||||
python-dotenv = "^1.0.0"
|
||||
cassio = "^0.1.0"
|
||||
tiktoken = ">=0.7,<1"
|
||||
anthropic = "^0.3.11"
|
||||
langchain-core = {path = "../core", develop = true}
|
||||
langchain-text-splitters = {path = "../text-splitters", develop = true}
|
||||
langchainhub = "^0.1.16"
|
||||
@@ -229,74 +173,18 @@ cli = ["typer"]
|
||||
# Please use new-line on formatting to make it easier to add new packages without
|
||||
# merge-conflicts
|
||||
extended_testing = [
|
||||
"aleph-alpha-client",
|
||||
"aiosqlite",
|
||||
"assemblyai",
|
||||
"beautifulsoup4",
|
||||
"bibtexparser",
|
||||
"cassio",
|
||||
"chardet",
|
||||
"datasets",
|
||||
"google-cloud-documentai",
|
||||
"esprima",
|
||||
"jq",
|
||||
"pdfminer-six",
|
||||
"pgvector",
|
||||
"pypdf",
|
||||
"pymupdf",
|
||||
"pypdfium2",
|
||||
"tqdm",
|
||||
"lxml",
|
||||
"atlassian-python-api",
|
||||
"mwparserfromhell",
|
||||
"mwxml",
|
||||
"msal",
|
||||
"pandas",
|
||||
"telethon",
|
||||
"psychicapi",
|
||||
"gql",
|
||||
"requests-toolbelt",
|
||||
"html2text",
|
||||
"numexpr",
|
||||
"py-trello",
|
||||
"scikit-learn",
|
||||
"streamlit",
|
||||
"pyspark",
|
||||
"openai",
|
||||
"sympy",
|
||||
"rapidfuzz",
|
||||
"jsonschema",
|
||||
"openai",
|
||||
"rank-bm25",
|
||||
"geopandas",
|
||||
"jinja2",
|
||||
"gitpython",
|
||||
"newspaper3k",
|
||||
"feedparser",
|
||||
"xata",
|
||||
"xmltodict",
|
||||
"faiss-cpu",
|
||||
"openapi-pydantic",
|
||||
"markdownify",
|
||||
"arxiv",
|
||||
"dashvector",
|
||||
"sqlite-vss",
|
||||
"rapidocr-onnxruntime",
|
||||
"motor",
|
||||
"timescale-vector",
|
||||
"anthropic",
|
||||
"upstash-redis",
|
||||
"rspace_client",
|
||||
"fireworks-ai",
|
||||
"javelin-sdk",
|
||||
"hologres-vector",
|
||||
"praw",
|
||||
"databricks-vectorsearch",
|
||||
"couchbase",
|
||||
"dgml-utils",
|
||||
"cohere",
|
||||
"langchain-openai",
|
||||
"rdflib",
|
||||
"langchain-openai",
|
||||
"langchain-anthropic",
|
||||
"langchain-fireworks",
|
||||
"langchain-together",
|
||||
"langchain-mistralai",
|
||||
"langchain-groq",
|
||||
"openai",
|
||||
"tiktoken",
|
||||
"numexpr",
|
||||
"rapidfuzz",
|
||||
"aiosqlite",
|
||||
"jsonschema",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
|
@@ -1,12 +1,47 @@
|
||||
from langchain.chat_models.base import __all__
|
||||
import pytest
|
||||
|
||||
from langchain.chat_models.base import __all__, init_chat_model
|
||||
|
||||
# The exact public surface we expect langchain.chat_models.base to export.
EXPECTED_ALL = [
    "BaseChatModel",
    "SimpleChatModel",
    "agenerate_from_stream",
    "generate_from_stream",
    "init_chat_model",
]


def test_all_imports() -> None:
    """__all__ must match the expected export list (order-insensitive)."""
    assert set(EXPECTED_ALL) == set(__all__)
|
||||
|
||||
|
||||
@pytest.mark.requires(
    "langchain_openai",
    "langchain_anthropic",
    "langchain_fireworks",
    "langchain_together",
    "langchain_mistralai",
    "langchain_groq",
)
@pytest.mark.parametrize(
    "model_name, model_provider",
    [
        ("gpt-4o", "openai"),
        ("claude-3-opus-20240229", "anthropic"),
        ("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
        ("meta-llama/Llama-3-8b-chat-hf", "together"),
        ("mixtral-8x7b-32768", "groq"),
    ],
)
def test_init_chat_model(model_name: str, model_provider: str) -> None:
    """Smoke test: a dummy api_key suffices to construct each provider's model."""
    init_chat_model(model_name, model_provider=model_provider, api_key="foo")
|
||||
|
||||
|
||||
def test_init_missing_dep() -> None:
    """When the provider's integration package is absent, ImportError is raised."""
    with pytest.raises(ImportError):
        init_chat_model("gpt-4o", model_provider="openai")
|
||||
|
||||
|
||||
def test_init_unknown_provider() -> None:
    """An unrecognized model_provider value is rejected with ValueError."""
    with pytest.raises(ValueError):
        init_chat_model("foo", model_provider="bar")
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from langchain import chat_models
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"init_chat_model",
|
||||
"ChatOpenAI",
|
||||
"BedrockChat",
|
||||
"AzureChatOpenAI",
|
||||
|
Reference in New Issue
Block a user