langchain[minor]: add universal init_model (#22039)

Decisions to discuss (see the usage sketch after this list):
- only chat models
- model_provider isn't based on any existing values like llm-type, package names, or class names
- implemented as a function, not as a wrapper ChatModel
- function name (init_model)
- lives in langchain as opposed to community or core
- marked beta
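
For context, a minimal usage sketch of the new function, mirroring the docstring example in the diff below (assumes langchain-openai and langchain-anthropic are installed and the provider API keys are set in the environment):

    from langchain.chat_models import init_chat_model

    # Explicit provider:
    gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
    gpt_4o.invoke("what's your name")

    # Provider inferred from the "claude..." prefix:
    claude_opus = init_chat_model("claude-3-opus-20240229", temperature=0)
    claude_opus.invoke("what's your name")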
Commit 1a911018bc (parent 67012c2558) by Bagatur, 2024-06-05 14:39:40 -07:00, committed via GitHub.
8 changed files with 627 additions and 2994 deletions.


@@ -21,6 +21,7 @@ import warnings
from langchain_core._api import LangChainDeprecationWarning
from langchain._api.interactive_env import is_interactive_env
from langchain.chat_models.base import init_chat_model
def __getattr__(name: str) -> None:
@@ -41,6 +42,7 @@ def __getattr__(name: str) -> None:
__all__ = [
"init_chat_model",
"ChatOpenAI",
"BedrockChat",
"AzureChatOpenAI",


@@ -1,3 +1,7 @@
from importlib import util
from typing import Any, Optional
from langchain_core._api import beta
from langchain_core.language_models.chat_models import (
BaseChatModel,
SimpleChatModel,
@@ -10,4 +14,189 @@ __all__ = [
"SimpleChatModel",
"generate_from_stream",
"agenerate_from_stream",
"init_chat_model",
]
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
@beta()
def init_chat_model(
model: str, *, model_provider: Optional[str] = None, **kwargs: Any
) -> BaseChatModel:
"""Initialize a ChatModel from the model name and provider.
Must have the integration package corresponding to the model provider installed.
Args:
model: The name of the model, e.g. "gpt-4o", "claude-3-opus-20240229".
model_provider: The model provider. Supported model_provider values and the
corresponding integration package:
- openai (langchain-openai)
- anthropic (langchain-anthropic)
- azure_openai (langchain-openai)
- google_vertexai (langchain-google-vertexai)
- google_genai (langchain-google-genai)
- bedrock (langchain-aws)
- cohere (langchain-cohere)
- fireworks (langchain-fireworks)
- together (langchain-together)
- mistralai (langchain-mistralai)
- huggingface (langchain-huggingface)
- groq (langchain-groq)
- ollama (langchain-community)
Will attempt to infer model_provider from model if not specified. The
following providers will be inferred based on these model prefixes:
- gpt-3... or gpt-4... -> openai
- claude... -> anthropic
- amazon.... -> bedrock
- gemini... -> google_vertexai
- command... -> cohere
- accounts/fireworks... -> fireworks
kwargs: Additional keyword args to pass to
``<<selected ChatModel>>.__init__(model=model_name, **kwargs)``.
Returns:
The BaseChatModel corresponding to the model_name and model_provider specified.
Raises:
ValueError: If model_provider cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
Example:
.. code-block:: python
from langchain.chat_models import init_chat_model
gpt_4o = init_chat_model("gpt-4o", model_provider="openai", temperature=0)
claude_opus = init_chat_model("claude-3-opus-20240229", model_provider="anthropic", temperature=0)
gemini_15 = init_chat_model("gemini-1.5-pro", model_provider="google_vertexai", temperature=0)
gpt_4o.invoke("what's your name")
claude_opus.invoke("what's your name")
gemini_15.invoke("what's your name")
""" # noqa: E501
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
raise ValueError(
f"Unable to infer model provider for {model=}, please specify "
f"model_provider directly."
)
model_provider = model_provider.replace("-", "_").lower()
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
elif model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs)
elif model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
elif model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
elif model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
elif model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
elif model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
elif model_provider == "ollama":
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
return ChatOllama(model=model, **kwargs)
elif model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
elif model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs)
elif model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace(model_id=model, **kwargs)
elif model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
elif model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
# TODO: update to use model= once ChatBedrock supports it
return ChatBedrock(model_id=model, **kwargs)
else:
supported = ", ".join(_SUPPORTED_PROVIDERS)
raise ValueError(
f"Unsupported {model_provider=}.\n\nSupported model providers are: "
f"{supported}"
)
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
}
def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
if model_name.startswith("gpt-3") or model_name.startswith("gpt-4"):
return "openai"
elif model_name.startswith("claude"):
return "anthropic"
elif model_name.startswith("command"):
return "cohere"
elif model_name.startswith("accounts/fireworks"):
return "fireworks"
elif model_name.startswith("gemini"):
return "google_vertexai"
elif model_name.startswith("amazon."):
return "bedrock"
else:
return None
def _check_pkg(pkg: str) -> None:
if not util.find_spec(pkg):
pkg_kebab = pkg.replace("_", "-")
raise ImportError(
f"Unable to import {pkg_kebab}. Please install with "
f"`pip install -U {pkg_kebab}`"
)
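
To illustrate the dispatch above, a hypothetical session (assumes langchain-openai is installed with OPENAI_API_KEY set, and that langchain-anthropic is not installed):

    from langchain.chat_models.base import init_chat_model

    # "gpt-4o" matches the gpt-4 prefix, so model_provider="openai" is inferred
    # and a ChatOpenAI(model="gpt-4o") instance is returned.
    model = init_chat_model("gpt-4o")

    # No known prefix and no model_provider given -> ValueError from the
    # inference step ("Unable to infer model provider...").
    init_chat_model("my-custom-model")

    # Inferred provider whose integration package is missing -> ImportError
    # from _check_pkg ("Unable to import langchain-anthropic...").
    init_chat_model("claude-3-opus-20240229")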

(One file's diff suppressed because it is too large.)


@@ -22,15 +22,11 @@ PyYAML = ">=5.3"
numpy = "^1"
aiohttp = "^3.8.3"
tenacity = "^8.1.0"
async-timeout = {version = "^4.0.0", python = "<3.11"}
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true}
manifest-ml = {version = "^0.0.1", optional = true}
transformers = {version = "^4", optional = true}
beautifulsoup4 = {version = "^4", optional = true}
torch = {version = ">=1,<3", optional = true}
jinja2 = {version = "^3", optional = true}
tiktoken = {version = ">=0.7,<1.0", optional = true, python=">=3.9"}
qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
cohere = {version = ">=4,<6", optional = true}
@@ -38,77 +34,28 @@ openai = {version = "<2", optional = true}
nlpcloud = {version = "^1", optional = true}
huggingface_hub = {version = "^0", optional = true}
sentence-transformers = {version = "^2", optional = true}
arxiv = {version = "^1.4", optional = true}
pypdf = {version = "^3.4.0", optional = true}
aleph-alpha-client = {version="^2.15.0", optional = true}
pgvector = {version = "^0.1.6", optional = true}
async-timeout = {version = "^4.0.0", python = "<3.11"}
azure-identity = {version = "^1.12.0", optional=true}
atlassian-python-api = {version = "^3.36.0", optional=true}
html2text = {version="^2020.1.16", optional=true}
numexpr = {version="^2.8.6", optional=true}
azure-cosmos = {version="^4.4.0b1", optional=true}
jq = {version = "^1.4.1", optional = true}
pdfminer-six = {version = "^20221105", optional = true}
docarray = {version="^0.32.0", extras=["hnswlib"], optional=true}
lxml = {version = ">=4.9.3,<6.0", optional = true}
pymupdf = {version = "^1.22.3", optional = true}
rapidocr-onnxruntime = {version = "^1.3.2", optional = true, python = ">=3.8.1,<3.12"}
pypdfium2 = {version = "^4.10.0", optional = true}
gql = {version = "^3.4.1", optional = true}
pandas = {version = "^2.0.1", optional = true}
telethon = {version = "^1.28.5", optional = true}
chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true}
openlm = {version = "^0.0.5", optional = true}
scikit-learn = {version = "^1.2.2", optional = true}
azure-ai-formrecognizer = {version = "^3.2.1", optional = true}
azure-cognitiveservices-speech = {version = "^1.28.0", optional = true}
py-trello = {version = "^0.19.0", optional = true}
bibtexparser = {version = "^1.4.0", optional = true}
pyspark = {version = "^3.4.0", optional = true}
clarifai = {version = ">=9.1.0", optional = true}
mwparserfromhell = {version = "^0.6.4", optional = true}
mwxml = {version = "^0.3.3", optional = true}
azure-search-documents = {version = "11.4.0b8", optional = true}
esprima = {version = "^4.0.1", optional = true}
streamlit = {version = "^1.18.0", optional = true, python = ">=3.8.1,<3.9.7 || >3.9.7,<4.0"}
psychicapi = {version = "^0.8.0", optional = true}
cassio = {version = "^0.1.0", optional = true}
sympy = {version = "^1.12", optional = true}
rapidfuzz = {version = "^3.1.1", optional = true}
jsonschema = {version = ">1", optional = true}
rank-bm25 = {version = "^0.2.2", optional = true}
geopandas = {version = "^0.13.1", optional = true}
gitpython = {version = "^3.1.32", optional = true}
feedparser = {version = "^6.0.10", optional = true}
newspaper3k = {version = "^0.2.8", optional = true}
xata = {version = "^1.0.0a7", optional = true}
xmltodict = {version = "^0.13.0", optional = true}
markdownify = {version = "^0.11.6", optional = true}
assemblyai = {version = "^0.17.0", optional = true}
dashvector = {version = "^1.0.1", optional = true}
sqlite-vss = {version = "^0.1.2", optional = true}
motor = {version = "^3.3.1", optional = true}
timescale-vector = {version = "^0.0.1", optional = true}
typer = {version= "^0.9.0", optional = true}
anthropic = {version = "^0.3.11", optional = true}
aiosqlite = {version = "^0.19.0", optional = true}
rspace_client = {version = "^2.5.0", optional = true}
upstash-redis = {version = "^0.15.0", optional = true}
azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.9.0", optional = true}
javelin-sdk = {version = "^0.1.8", optional = true}
hologres-vector = {version = "^0.0.6", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
couchbase = {version = "^4.1.9", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
langchain-openai = {version = "^0.1", optional = true}
rdflib = {version = "7.0.0", optional = true}
langchain-openai = {version = "^0", optional = true}
langchain-anthropic = {version = "^0", optional = true}
langchain-fireworks = {version = "^0", optional = true}
langchain-together = {version = "^0", optional = true}
langchain-mistralai = {version = "^0", optional = true}
langchain-groq = {version = "^0", optional = true}
jsonschema = {version = "^4.22.0", optional = true}
[tool.poetry.group.test]
optional = true
@@ -162,11 +109,8 @@ optional = true
# https://python.langchain.com/docs/contributing/code#working-with-optional-dependencies
pytest-vcr = "^1.0.2"
wrapt = "^1.15.0"
openai = "^1"
python-dotenv = "^1.0.0"
cassio = "^0.1.0"
tiktoken = ">=0.7,<1"
anthropic = "^0.3.11"
langchain-core = {path = "../core", develop = true}
langchain-text-splitters = {path = "../text-splitters", develop = true}
langchainhub = "^0.1.16"
@@ -229,74 +173,18 @@ cli = ["typer"]
# Please use new-line on formatting to make it easier to add new packages without
# merge-conflicts
extended_testing = [
"aleph-alpha-client",
"aiosqlite",
"assemblyai",
"beautifulsoup4",
"bibtexparser",
"cassio",
"chardet",
"datasets",
"google-cloud-documentai",
"esprima",
"jq",
"pdfminer-six",
"pgvector",
"pypdf",
"pymupdf",
"pypdfium2",
"tqdm",
"lxml",
"atlassian-python-api",
"mwparserfromhell",
"mwxml",
"msal",
"pandas",
"telethon",
"psychicapi",
"gql",
"requests-toolbelt",
"html2text",
"numexpr",
"py-trello",
"scikit-learn",
"streamlit",
"pyspark",
"openai",
"sympy",
"rapidfuzz",
"jsonschema",
"openai",
"rank-bm25",
"geopandas",
"jinja2",
"gitpython",
"newspaper3k",
"feedparser",
"xata",
"xmltodict",
"faiss-cpu",
"openapi-pydantic",
"markdownify",
"arxiv",
"dashvector",
"sqlite-vss",
"rapidocr-onnxruntime",
"motor",
"timescale-vector",
"anthropic",
"upstash-redis",
"rspace_client",
"fireworks-ai",
"javelin-sdk",
"hologres-vector",
"praw",
"databricks-vectorsearch",
"couchbase",
"dgml-utils",
"cohere",
"langchain-openai",
"rdflib",
"langchain-openai",
"langchain-anthropic",
"langchain-fireworks",
"langchain-together",
"langchain-mistralai",
"langchain-groq",
"openai",
"tiktoken",
"numexpr",
"rapidfuzz",
"aiosqlite",
"jsonschema",
]
[tool.ruff]


@@ -1,12 +1,47 @@
from langchain.chat_models.base import __all__
import pytest
from langchain.chat_models.base import __all__, init_chat_model
EXPECTED_ALL = [
"BaseChatModel",
"SimpleChatModel",
"agenerate_from_stream",
"generate_from_stream",
"init_chat_model",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)
@pytest.mark.requires(
"langchain_openai",
"langchain_anthropic",
"langchain_fireworks",
"langchain_together",
"langchain_mistralai",
"langchain_groq",
)
@pytest.mark.parametrize(
["model_name", "model_provider"],
[
("gpt-4o", "openai"),
("claude-3-opus-20240229", "anthropic"),
("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
("meta-llama/Llama-3-8b-chat-hf", "together"),
("mixtral-8x7b-32768", "groq"),
],
)
def test_init_chat_model(model_name: str, model_provider: str) -> None:
init_chat_model(model_name, model_provider=model_provider, api_key="foo")
def test_init_missing_dep() -> None:
with pytest.raises(ImportError):
init_chat_model("gpt-4o", model_provider="openai")
def test_init_unknown_provider() -> None:
with pytest.raises(ValueError):
init_chat_model("foo", model_provider="bar")


@@ -1,6 +1,7 @@
from langchain import chat_models
EXPECTED_ALL = [
"init_chat_model",
"ChatOpenAI",
"BedrockChat",
"AzureChatOpenAI",