This commit is contained in:
Erick Friis 2024-11-19 16:16:57 -08:00
parent 6bda89f9a1
commit fa2f383107
3 changed files with 149 additions and 0 deletions

View File

@ -14,6 +14,15 @@ from langchain_core._api import (
surface_langchain_beta_warnings,
surface_langchain_deprecation_warnings,
)
from langchain_core._api.shorthand import (
chat_model,
document_loader,
embeddings,
retriever,
tool,
toolkit,
vectorstore,
)
try:
__version__ = metadata.version(__package__)
@ -23,3 +32,14 @@ except metadata.PackageNotFoundError:
surface_langchain_deprecation_warnings()
surface_langchain_beta_warnings()
__all__ = [
"__version__",
"chat_model",
"retriever",
"tool",
"toolkit",
"document_loader",
"vectorstore",
"embeddings",
]

View File

@ -0,0 +1,93 @@
"""
Functions that support shorthand access of LangChain classes like
::code
import langchain_core as lc
model = lc.chat_model("claude-3-5-sonnet-20240620", provider="anthropic")
"""
import importlib
from typing import TYPE_CHECKING, Literal, cast
if TYPE_CHECKING:
from langchain_core.document_loaders import BaseLoader
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import BaseTool, BaseToolkit
from langchain_core.vectorstores import VectorStore
ProviderManifestType = dict[
Literal[
"chat_model",
"retriever",
"embeddings",
"vectorstore",
"tool",
"toolkit",
"document_loader",
],
list[str],
]
_registered_providers: dict[
str,
ProviderManifestType,
] = {}
def register_manifest(
provider: str,
provider_manifest: ProviderManifestType,
*,
dangerously_allow_provider_overwrite: bool = False,
) -> None:
if provider in _registered_providers and not dangerously_allow_provider_overwrite:
msg = (
f"Provider `{provider}` was already registered. To allow overwriting, "
"pass `dangerously_allow_provider_overwrite=True`"
)
raise ValueError(msg)
_registered_providers[provider] = provider_manifest
def register(package_root: str) -> None:
manifest_module = importlib.import_module("lc_manifest", package_root)
if not hasattr(manifest_module, "manifest"):
msg = (
f"Package registration requires a {package_root}.lc_manifest.manifest "
"dictionary to be declared"
)
raise ValueError(msg)
manifest = cast(dict[str, ProviderManifestType], manifest_module.manifest)
for provider, provider_manifest in manifest.items():
register_manifest(provider, provider_manifest)
def chat_model(model: str, *, provider: str | None = None, **kwargs) -> BaseChatModel:
pass
def retriever(name: str, **kwargs) -> BaseRetriever:
pass
def tool(name: str, **kwargs) -> BaseTool:
pass
def toolkit(name: str, **kwargs) -> BaseToolkit:
pass
def document_loader(name: str, **kwargs) -> BaseLoader:
pass
def vectorstore(name: str, **kwargs) -> VectorStore:
pass
def embeddings(model: str, *, provider: str | None = None, **kwargs) -> Embeddings:
pass

View File

@ -0,0 +1,36 @@
from typing import Literal
# we want to register integration classes, as well as provide a hook
manifest: dict[
str,
dict[
Literal[
"chat_model",
"retriever",
"embeddings",
"vectorstore",
"tool",
"toolkit",
"document_loader",
],
list[str],
],
] = {
"openai": {
"chat_model": [
"langchain_openai.chat_models.base.ChatOpenAI",
],
"embeddings": [
"langchain_openai.embeddings.base.OpenAIEmbeddings",
],
},
"azure": {
"chat_model": [
"langchain_openai.chat_models.azure.AzureChatOpenAI",
],
"embeddings": [
"langchain_openai.embeddings.azure.AzureOpenAIEmbeddings",
],
},
}