commit 7db6aabf65 (parent ed62984cb2)

Update chat model output type (#11833)

Co-authored-by: Bagatur <baskaryan@gmail.com>
@@ -11,7 +11,6 @@ from typing import (
     List,
     Optional,
     Sequence,
-    Union,
     cast,
 )
 
@@ -38,12 +37,10 @@ from langchain.schema import (
 from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
 from langchain.schema.messages import (
     AIMessage,
+    AnyMessage,
     BaseMessage,
     BaseMessageChunk,
-    ChatMessage,
-    FunctionMessage,
     HumanMessage,
-    SystemMessage,
 )
 from langchain.schema.output import ChatGenerationChunk
 from langchain.schema.runnable import RunnableConfig
@@ -79,7 +76,7 @@ async def _agenerate_from_stream(
     return ChatResult(generations=[generation])
 
 
-class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
+class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     """Base class for Chat models."""
 
     cache: Optional[bool] = None
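A note on what this one-line change does: BaseLanguageModel is generic over its output type, and the parameter chosen here is what invoke() and ainvoke() are declared to return downstream. A minimal self-contained sketch of the mechanism, using hypothetical stand-in names rather than the real langchain classes:

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class Message:
    """Hypothetical stand-in for langchain's BaseMessage."""


class MessageChunk(Message):
    """Hypothetical stand-in for BaseMessageChunk, a subtype of Message."""


class LanguageModelSketch(Generic[T]):
    """Stand-in for BaseLanguageModel[T]; T is what invoke() returns."""

    def invoke(self, prompt: str) -> T:
        raise NotImplementedError


class ChatModelSketch(LanguageModelSketch[Message]):
    # Parameterizing with Message rather than MessageChunk mirrors this
    # hunk: a plain (non-streaming) call produces a complete message, so
    # that is what the class should advertise.
    def invoke(self, prompt: str) -> Message:
        return Message()


print(type(ChatModelSketch().invoke("hi")).__name__)  # -> Message
```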
@@ -116,9 +113,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     @property
     def OutputType(self) -> Any:
         """Get the output type for this runnable."""
-        return Union[
-            HumanMessage, AIMessage, ChatMessage, FunctionMessage, SystemMessage
-        ]
+        return AnyMessage
 
     def _convert_input(self, input: LanguageModelInput) -> PromptValue:
         if isinstance(input, PromptValue):
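AnyMessage, newly imported above, stands in for the inline Union this hunk deletes. Judging by the snapshot change at the end of this commit, it is a Union alias over the same five message classes, in the order AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage. A hedged sketch of the equivalence with local stand-in classes, not the library's actual definition:

```python
from typing import Union, get_args


class AIMessage: ...
class HumanMessage: ...
class ChatMessage: ...
class SystemMessage: ...
class FunctionMessage: ...


# Local alias; langchain.schema.messages.AnyMessage is believed to be a
# Union over these same classes (member order taken from the snapshot
# hunk at the bottom of this commit).
AnyMessageSketch = Union[
    AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage
]

# get_args recovers the union members; schema generation enumerates them
# into the 'anyOf' list that the snapshot test checks.
print(get_args(AnyMessageSketch))
```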
@@ -140,23 +135,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         stop: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> BaseMessageChunk:
+    ) -> BaseMessage:
         config = config or {}
         return cast(
-            BaseMessageChunk,
-            cast(
-                ChatGeneration,
-                self.generate_prompt(
-                    [self._convert_input(input)],
-                    stop=stop,
-                    callbacks=config.get("callbacks"),
-                    tags=config.get("tags"),
-                    metadata=config.get("metadata"),
-                    run_name=config.get("run_name"),
-                    **kwargs,
-                ).generations[0][0],
-            ).message,
-        )
+            ChatGeneration,
+            self.generate_prompt(
+                [self._convert_input(input)],
+                stop=stop,
+                callbacks=config.get("callbacks"),
+                tags=config.get("tags"),
+                metadata=config.get("metadata"),
+                run_name=config.get("run_name"),
+                **kwargs,
+            ).generations[0][0],
+        ).message
 
     async def ainvoke(
         self,
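The old body cast twice: an inner cast to ChatGeneration to reach .message, and an outer cast to BaseMessageChunk to satisfy the old return annotation. With the annotation widened to BaseMessage, .message already fits and the outer cast goes away; the ainvoke hunks below get the same one-cast treatment. A self-contained sketch of the surviving pattern, with hypothetical stand-ins for the generation classes:

```python
from typing import List, cast


class Message:
    """Hypothetical stand-in for BaseMessage."""


class Generation:
    """Base generation type: no .message attribute of its own."""


class ChatGeneration(Generation):
    def __init__(self, message: Message) -> None:
        self.message = message


def first_message(generations: List[List[Generation]]) -> Message:
    # One cast narrows the statically-known Generation to ChatGeneration
    # so .message type-checks; cast() does nothing at runtime.
    return cast(ChatGeneration, generations[0][0]).message


print(type(first_message([[ChatGeneration(Message())]])).__name__)
```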
@@ -165,7 +157,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         *,
         stop: Optional[List[str]] = None,
         **kwargs: Any,
-    ) -> BaseMessageChunk:
+    ) -> BaseMessage:
         config = config or {}
         llm_result = await self.agenerate_prompt(
             [self._convert_input(input)],
@@ -176,9 +168,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             run_name=config.get("run_name"),
             **kwargs,
         )
-        return cast(
-            BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
-        )
+        return cast(ChatGeneration, llm_result.generations[0][0]).message
 
     def stream(
         self,
@@ -190,7 +180,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     ) -> Iterator[BaseMessageChunk]:
         if type(self)._stream == BaseChatModel._stream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+            )
         else:
             config = config or {}
             messages = self._convert_input(input).to_messages()
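This added cast is the flip side of the signature change: stream() still promises BaseMessageChunk per item, but invoke() now returns the broader BaseMessage, so the degenerate one-item stream needs an explicit static narrowing (the astream hunk right after mirrors this). A sketch with stand-in types:

```python
from typing import Iterator, cast


class Message:
    """Hypothetical stand-in for BaseMessage."""


class MessageChunk(Message):
    """Hypothetical stand-in for BaseMessageChunk."""


def invoke_like() -> Message:
    return Message()


def stream_like() -> Iterator[MessageChunk]:
    # Mirrors the fallback path: no real streaming, so yield the single
    # invoke() result, cast to fit the Iterator[MessageChunk] annotation.
    # cast() is a runtime no-op, so this is purely a static narrowing.
    yield cast(MessageChunk, invoke_like())


for chunk in stream_like():
    print(type(chunk).__name__)  # -> Message (the cast is unchecked)
```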
@@ -241,7 +233,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
     ) -> AsyncIterator[BaseMessageChunk]:
         if type(self)._astream == BaseChatModel._astream:
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
+            )
         else:
             config = config or {}
             messages = self._convert_input(input).to_messages()
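The remaining hunk comes from a second file in the same commit; this extract does not preserve its name, but the `dict({ ... })` / `list([ ... ])` syntax and the `#/definitions/...` refs mark it as a serialized JSON-schema snapshot (likely a syrupy `.ambr` test snapshot) being updated to the new union member order.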
@@ -2163,19 +2163,19 @@
       dict({
         'anyOf': list([
           dict({
-            '$ref': '#/definitions/HumanMessage',
+            '$ref': '#/definitions/AIMessage',
           }),
           dict({
-            '$ref': '#/definitions/AIMessage',
+            '$ref': '#/definitions/HumanMessage',
           }),
           dict({
             '$ref': '#/definitions/ChatMessage',
           }),
           dict({
-            '$ref': '#/definitions/FunctionMessage',
+            '$ref': '#/definitions/SystemMessage',
           }),
           dict({
-            '$ref': '#/definitions/SystemMessage',
+            '$ref': '#/definitions/FunctionMessage',
           }),
         ]),
         'definitions': dict({
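The swapped $ref order tracks the declaration order of the new union: pydantic (v1, which langchain used at the time) emits a Union field's JSON schema as an anyOf list in member order, so switching OutputType to AnyMessage reorders the snapshot even though the set of alternatives is unchanged. A small demonstration of that behavior, assuming pydantic v1 is installed:

```python
from typing import Union

from pydantic import BaseModel  # assumes pydantic v1 behavior


class AIMessage(BaseModel):
    content: str = ""


class HumanMessage(BaseModel):
    content: str = ""


class Holder(BaseModel):
    # anyOf order in the generated schema follows Union declaration order.
    value: Union[AIMessage, HumanMessage]


print(Holder.schema()["properties"]["value"]["anyOf"])
# -> [{'$ref': '#/definitions/AIMessage'}, {'$ref': '#/definitions/HumanMessage'}]
```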