From 63916cfe3556788bc015ab9b9456e7087823ff2d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 26 Dec 2023 12:19:50 -0800 Subject: [PATCH] [core] langauge model like (#15180) --- libs/core/langchain_core/language_models/__init__.py | 2 ++ libs/core/langchain_core/language_models/base.py | 8 +++++--- .../core/tests/unit_tests/language_models/test_imports.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/libs/core/langchain_core/language_models/__init__.py b/libs/core/langchain_core/language_models/__init__.py index 4c2e597dddf..b86de1cdb13 100644 --- a/libs/core/langchain_core/language_models/__init__.py +++ b/libs/core/langchain_core/language_models/__init__.py @@ -1,6 +1,7 @@ from langchain_core.language_models.base import ( BaseLanguageModel, LanguageModelInput, + LanguageModelLike, LanguageModelOutput, get_tokenizer, ) @@ -16,4 +17,5 @@ __all__ = [ "LanguageModelInput", "get_tokenizer", "LanguageModelOutput", + "LanguageModelLike", ] diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index f48f83b37f4..b590e0df84f 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -17,7 +17,7 @@ from typing_extensions import TypeAlias from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string from langchain_core.prompt_values import PromptValue -from langchain_core.runnables import RunnableSerializable +from langchain_core.runnables import Runnable, RunnableSerializable from langchain_core.utils import get_pydantic_field_names if TYPE_CHECKING: @@ -49,11 +49,13 @@ def _get_token_ids_default_method(text: str) -> List[int]: LanguageModelInput = Union[PromptValue, str, List[BaseMessage]] -LanguageModelOutput = TypeVar("LanguageModelOutput") +LanguageModelOutput = Union[BaseMessage, str] +LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput] +LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str) class BaseLanguageModel( - RunnableSerializable[LanguageModelInput, LanguageModelOutput], ABC + RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC ): """Abstract base class for interfacing with language models. diff --git a/libs/core/tests/unit_tests/language_models/test_imports.py b/libs/core/tests/unit_tests/language_models/test_imports.py index 348111f8511..65d627c56fd 100644 --- a/libs/core/tests/unit_tests/language_models/test_imports.py +++ b/libs/core/tests/unit_tests/language_models/test_imports.py @@ -9,6 +9,7 @@ EXPECTED_ALL = [ "LanguageModelInput", "LanguageModelOutput", "get_tokenizer", + "LanguageModelLike", ]