diff --git a/libs/core/langchain_core/__init__.py b/libs/core/langchain_core/__init__.py index 0326da7c0de..62d1791108d 100644 --- a/libs/core/langchain_core/__init__.py +++ b/libs/core/langchain_core/__init__.py @@ -14,6 +14,15 @@ from langchain_core._api import ( surface_langchain_beta_warnings, surface_langchain_deprecation_warnings, ) +from langchain_core._api.shorthand import ( + chat_model, + document_loader, + embeddings, + retriever, + tool, + toolkit, + vectorstore, +) try: __version__ = metadata.version(__package__) @@ -23,3 +32,14 @@ except metadata.PackageNotFoundError: surface_langchain_deprecation_warnings() surface_langchain_beta_warnings() + +__all__ = [ + "__version__", + "chat_model", + "retriever", + "tool", + "toolkit", + "document_loader", + "vectorstore", + "embeddings", +] diff --git a/libs/core/langchain_core/_api/shorthand.py b/libs/core/langchain_core/_api/shorthand.py new file mode 100644 index 00000000000..fedfc8ebcc1 --- /dev/null +++ b/libs/core/langchain_core/_api/shorthand.py @@ -0,0 +1,93 @@ +""" +Functions that support shorthand access of LangChain classes like + +::code + import langchain_core as lc + + model = lc.chat_model("claude-3-5-sonnet-20240620", provider="anthropic") +""" + +import importlib +from typing import TYPE_CHECKING, Literal, cast + +if TYPE_CHECKING: + from langchain_core.document_loaders import BaseLoader + from langchain_core.embeddings import Embeddings + from langchain_core.language_models import BaseChatModel + from langchain_core.retrievers import BaseRetriever + from langchain_core.tools import BaseTool, BaseToolkit + from langchain_core.vectorstores import VectorStore + +ProviderManifestType = dict[ + Literal[ + "chat_model", + "retriever", + "embeddings", + "vectorstore", + "tool", + "toolkit", + "document_loader", + ], + list[str], +] +_registered_providers: dict[ + str, + ProviderManifestType, +] = {} + + +def register_manifest( + provider: str, + provider_manifest: ProviderManifestType, + *, + dangerously_allow_provider_overwrite: bool = False, +) -> None: + if provider in _registered_providers and not dangerously_allow_provider_overwrite: + msg = ( + f"Provider `{provider}` was already registered. To allow overwriting, " + "pass `dangerously_allow_provider_overwrite=True`" + ) + raise ValueError(msg) + _registered_providers[provider] = provider_manifest + + +def register(package_root: str) -> None: + manifest_module = importlib.import_module("lc_manifest", package_root) + if not hasattr(manifest_module, "manifest"): + msg = ( + f"Package registration requires a {package_root}.lc_manifest.manifest " + "dictionary to be declared" + ) + raise ValueError(msg) + + manifest = cast(dict[str, ProviderManifestType], manifest_module.manifest) + for provider, provider_manifest in manifest.items(): + register_manifest(provider, provider_manifest) + + +def chat_model(model: str, *, provider: str | None = None, **kwargs) -> BaseChatModel: + pass + + +def retriever(name: str, **kwargs) -> BaseRetriever: + pass + + +def tool(name: str, **kwargs) -> BaseTool: + pass + + +def toolkit(name: str, **kwargs) -> BaseToolkit: + pass + + +def document_loader(name: str, **kwargs) -> BaseLoader: + pass + + +def vectorstore(name: str, **kwargs) -> VectorStore: + pass + + +def embeddings(model: str, *, provider: str | None = None, **kwargs) -> Embeddings: + pass diff --git a/libs/partners/openai/langchain_openai/lc_manifest.py b/libs/partners/openai/langchain_openai/lc_manifest.py new file mode 100644 index 00000000000..950cd558dec --- /dev/null +++ b/libs/partners/openai/langchain_openai/lc_manifest.py @@ -0,0 +1,36 @@ +from typing import Literal + +# we want to register integration classes, as well as provide a hook + +manifest: dict[ + str, + dict[ + Literal[ + "chat_model", + "retriever", + "embeddings", + "vectorstore", + "tool", + "toolkit", + "document_loader", + ], + list[str], + ], +] = { + "openai": { + "chat_model": [ + "langchain_openai.chat_models.base.ChatOpenAI", + ], + "embeddings": [ + "langchain_openai.embeddings.base.OpenAIEmbeddings", + ], + }, + "azure": { + "chat_model": [ + "langchain_openai.chat_models.azure.AzureChatOpenAI", + ], + "embeddings": [ + "langchain_openai.embeddings.azure.AzureOpenAIEmbeddings", + ], + }, +}