From 86a3e5144014a2b52cd67d2bfd0b8ecfa6afe715 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sun, 27 Aug 2023 10:49:00 -0700
Subject: [PATCH] stash

---
 .../agents/openai_functions_agent/base.py     |  3 ++-
 .../openai_functions_multi_agent/base.py      |  4 ++--
 libs/langchain/langchain/chat_models/base.py  | 20 +++++++++++++++++--
 libs/langchain/langchain/llms/base.py         | 18 +++++++++++++++--
 libs/langchain/langchain/prompts/chat.py      |  6 ++++--
 libs/langchain/langchain/schema/prompt.py     |  3 +++
 6 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py
index 19d5ebbc433..dab1a1d6647 100644
--- a/libs/langchain/langchain/agents/openai_functions_agent/base.py
+++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py
@@ -212,6 +212,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
                 messages,
                 functions=self.functions,
                 callbacks=callbacks,
+                **prompt.llm_kwargs,
             )
         else:
             predicted_message = self.llm.predict_messages(
@@ -245,7 +246,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
index fcc51227fda..581fe0c20a7 100644
--- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
+++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
@@ -267,7 +267,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
@@ -296,7 +296,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py
index 2d0db37c0ac..f4f4850a0ac 100644
--- a/libs/langchain/langchain/chat_models/base.py
+++ b/libs/langchain/langchain/chat_models/base.py
@@ -160,7 +160,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # kwargs from prompt defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = CallbackManager.configure(
@@ -207,7 +210,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # prompt kwargs defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = AsyncCallbackManager.configure(
@@ -410,6 +416,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
 
@@ -420,6 +431,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return await self.agenerate(
             prompt_messages, stop=stop, callbacks=callbacks, **kwargs
diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index a91ecd9f2ac..95c2b8d3607 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -334,7 +334,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -381,7 +383,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -463,6 +467,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
 
@@ -473,6 +482,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(
             prompt_strings, stop=stop, callbacks=callbacks, **kwargs
diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py
index 7e7e5809f31..ef163bd6377 100644
--- a/libs/langchain/langchain/prompts/chat.py
+++ b/libs/langchain/langchain/prompts/chat.py
@@ -298,6 +298,7 @@ class ChatPromptValue(PromptValue):
 
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""
+    llm_kwargs: dict = Field(default_factory=dict)
 
     @property
     def lc_attributes(self) -> Dict:
@@ -330,7 +331,7 @@
             PromptValue.
         """
         messages = self.format_messages(**kwargs)
-        return ChatPromptValue(messages=messages)
+        return ChatPromptValue(messages=messages, llm_kwargs=self.llm_kwargs)
 
     @abstractmethod
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
@@ -489,6 +490,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
     def from_messages(
         cls,
         messages: Sequence[MessageLikeRepresentation],
+        **kwargs: Any,
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.
 
@@ -534,7 +536,7 @@
         ):
             input_vars.update(_message.input_variables)
 
-        return cls(input_variables=sorted(input_vars), messages=_messages)
+        return cls(input_variables=sorted(input_vars), messages=_messages, **kwargs)
 
     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.
diff --git a/libs/langchain/langchain/schema/prompt.py b/libs/langchain/langchain/schema/prompt.py
index 951954c5f2e..10d5759ed70 100644
--- a/libs/langchain/langchain/schema/prompt.py
+++ b/libs/langchain/langchain/schema/prompt.py
@@ -3,6 +3,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import List
 
+from langchain.pydantic_v1 import Field
+
 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage
 
@@ -13,6 +15,7 @@ class PromptValue(Serializable, ABC):
     PromptValues can be converted to both LLM (pure text-generation) inputs and
     ChatModel inputs.
     """
+    llm_kwargs: dict = Field(default_factory=dict)
 
     @property
     def lc_serializable(self) -> bool:
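--
Reviewer note (not part of the patch): below is a minimal usage sketch of the
llm_kwargs plumbing this diff introduces, inferred only from the hunks above.
The ChatOpenAI model and the max_tokens values are illustrative assumptions;
any keyword argument the underlying model's generate call accepts should flow
the same way.

    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import ChatPromptTemplate

    # llm_kwargs attached at construction time (via the new **kwargs on
    # from_messages) ride along on the ChatPromptValue that format_prompt
    # returns, per the prompts/chat.py and schema/prompt.py hunks.
    prompt = ChatPromptTemplate.from_messages(
        [("human", "Tell me a joke about {topic}")],
        llm_kwargs={"max_tokens": 64},  # assumed: a kwarg the model call accepts
    )
    prompt_value = prompt.format_prompt(topic="bears")

    # generate_prompt merges the prompt's llm_kwargs into the model call;
    # explicitly passed kwargs win on conflict ({**llm_kwargs, **kwargs}).
    model = ChatOpenAI()
    result = model.generate_prompt([prompt_value], max_tokens=128)

One behavioral consequence worth flagging: generate_prompt/agenerate_prompt now
require every prompt in a batch to carry identical llm_kwargs and raise a
ValueError otherwise, so mixed-kwargs batches must be split up by the caller.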