mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-19 01:21:50 +00:00
Merge 5dcb5cb987
into 3a487bf720
This commit is contained in:
commit
3ddfd1f79d
@ -51,7 +51,12 @@ from langchain_core.outputs import (
|
||||
ChatResult,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableMap,
|
||||
RunnablePassthrough,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import (
|
||||
convert_to_json_schema,
|
||||
@ -714,6 +719,34 @@ class ChatHuggingFace(BaseChatModel):
|
||||
)
|
||||
yield generation_chunk
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: list[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> list[BaseMessage]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
if _is_huggingface_pipeline(self.llm):
|
||||
prompt_messages = [
|
||||
self._convert_input(input_).to_messages() for input_ in inputs
|
||||
]
|
||||
|
||||
llm_inputs = list(map(self._to_chat_prompt, prompt_messages))
|
||||
|
||||
llm_results = self.llm._generate(prompts=llm_inputs)
|
||||
chat_result = self._to_chat_result(llm_results)
|
||||
return [gen.message for gen in chat_result.generations]
|
||||
return super().batch(
|
||||
inputs=inputs,
|
||||
config=config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _to_chat_prompt(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -750,10 +783,10 @@ class ChatHuggingFace(BaseChatModel):
|
||||
@staticmethod
|
||||
def _to_chat_result(llm_result: LLMResult) -> ChatResult:
|
||||
chat_generations = []
|
||||
|
||||
for g in llm_result.generations[0]:
|
||||
for g in llm_result.generations:
|
||||
chat_generation = ChatGeneration(
|
||||
message=AIMessage(content=g.text), generation_info=g.generation_info
|
||||
message=AIMessage(content=g[0].text),
|
||||
generation_info=g[0].generation_info,
|
||||
)
|
||||
chat_generations.append(chat_generation)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user