Compare commits


2 Commits

Author   SHA1        Message                    Date
Bagatur  5e90a832c3  undo                       2024-02-07 10:40:23 -08:00
Bagatur  4b5a2b6b00  rfc: return full llm res   2024-02-07 10:39:26 -08:00


@@ -11,9 +11,12 @@ from typing import (
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
    overload,
)
from langchain_core._api import deprecated
@@ -152,27 +155,55 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                "Must be a PromptValue, str, or list of BaseMessages."
            )

    @overload
    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        return_type: Literal["message"] = "message",
        **kwargs: Any,
    ) -> BaseMessage:
        ...

    @overload
    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        return_type: Literal["full"],
        **kwargs: Any,
    ) -> LLMResult:
        ...

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        return_type: Literal["message", "full"] = "message",
        **kwargs: Any,
    ) -> Union[BaseMessage, LLMResult]:
        config = ensure_config(config)
        return cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                **kwargs,
            ).generations[0][0],
        ).message
        llm_result = self.generate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )
        if return_type == "full":
            return llm_result
        elif return_type == "message":
            return cast(ChatGeneration, llm_result.generations[0][0]).message
        else:
            raise ValueError

    async def ainvoke(
        self,
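
For context, below is a minimal usage sketch of what this RFC would let callers do: opt into the full LLMResult (including provider metadata in llm_output) instead of only the final BaseMessage. The `return_type` parameter comes from the diff above; ChatOpenAI is assumed here purely as a concrete BaseChatModel example, and this usage is not part of any released API.

# Hypothetical usage sketch of the proposed `return_type` parameter.
# Assumes ChatOpenAI (langchain_openai) as a stand-in chat model; any
# BaseChatModel subclass would behave the same under this proposal.
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-3.5-turbo")

# Default behaviour is unchanged: invoke returns a single message.
message = model.invoke("Tell me a joke", return_type="message")
print(message.content)

# Opting into "full" would return the whole LLMResult, which also
# carries provider metadata such as token usage in `llm_output`.
result = model.invoke("Tell me a joke", return_type="full")
print(result.generations[0][0].message.content)
print(result.llm_output)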