Compare commits

3 Commits

Author          SHA1        Message         Date
Harrison Chase  39bd4c7fe4  cr              2023-08-27 11:15:45 -07:00
Harrison Chase  60818aeaaa  add llm kwargs  2023-08-27 11:11:13 -07:00
Harrison Chase  86a3e51440  stash           2023-08-27 10:49:00 -07:00
7 changed files with 61 additions and 18 deletions
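Taken together, the three commits thread a per-prompt llm_kwargs dict from the prompt templates through PromptValue and into the model call, with call-time kwargs winning on conflict. A minimal standalone sketch of that flow (plain Python with illustrative names, no langchain imports):

    from dataclasses import dataclass, field

    @dataclass
    class FakePromptValue:
        text: str
        llm_kwargs: dict = field(default_factory=dict)

    def fake_predict(prompt: FakePromptValue, **kwargs):
        # Kwargs baked into the prompt apply first; call-time kwargs override.
        merged = {**prompt.llm_kwargs, **kwargs}
        return f"model called with {merged}"

    pv = FakePromptValue("hi", llm_kwargs={"temperature": 0.0})
    print(fake_predict(pv))                   # {'temperature': 0.0}
    print(fake_predict(pv, temperature=0.7))  # {'temperature': 0.7}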

View File

@@ -212,6 +212,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
                 messages,
                 functions=self.functions,
                 callbacks=callbacks,
+                **prompt.llm_kwargs,
             )
         else:
             predicted_message = self.llm.predict_messages(
@@ -245,7 +246,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision

View File

@@ -267,7 +267,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
@@ -296,7 +296,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
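One property of the bare **prompt.llm_kwargs splat in both agent files: Python raises a TypeError if llm_kwargs repeats a keyword that is already passed explicitly (functions or callbacks). A standalone illustration with a stand-in function:

    def predict_messages(messages, functions=None, callbacks=None, **kwargs):
        return {"functions": functions, **kwargs}

    # A kwarg like function_call passes through cleanly.
    predict_messages([], functions=[], **{"function_call": {"name": "search"}})

    try:
        predict_messages([], functions=[], **{"functions": []})
    except TypeError as err:
        print(err)  # ... got multiple values for keyword argument 'functions'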

View File

@@ -160,7 +160,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # kwargs from prompt defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = CallbackManager.configure(
@@ -207,7 +210,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # prompt kwargs defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = AsyncCallbackManager.configure(
@@ -410,6 +416,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
@@ -420,6 +433,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return await self.agenerate(
             prompt_messages, stop=stop, callbacks=callbacks, **kwargs
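The batch paths cannot merge per-prompt kwargs individually, since generate/agenerate make one logical call over all prompts, so the new guard insists every prompt in the batch carries identical llm_kwargs. A standalone mirror of that check (function name is illustrative):

    def merge_batch_kwargs(prompts_llm_kwargs: list, call_kwargs: dict) -> dict:
        # All prompts in a batch must agree on their kwargs.
        first = prompts_llm_kwargs[0]  # note: an empty batch would raise IndexError
        for lk in prompts_llm_kwargs:
            if lk != first:
                raise ValueError(
                    "All prompt kwargs must be the same when calling in batch"
                )
        # Call-time kwargs still override the shared prompt kwargs.
        return {**first, **call_kwargs}

    print(merge_batch_kwargs([{"stop": ["\n"]}, {"stop": ["\n"]}], {}))  # ok
    merge_batch_kwargs([{"stop": ["\n"]}, {}], {})  # raises ValueError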

View File

@@ -334,7 +334,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -381,7 +383,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -463,6 +467,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
@@ -473,6 +484,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError(
+                    "All prompt kwargs must be the same when calling in batch"
+                )
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(
             prompt_strings, stop=stop, callbacks=callbacks, **kwargs
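In all four streaming paths the inline comment ("kwargs from prompt defer to kwargs passed in") is the whole contract, and the dict-unpacking order encodes it. A two-line check of those semantics:

    prompt_kwargs = {"temperature": 0.0, "stop": ["###"]}
    call_kwargs = {"temperature": 0.9}

    # Later unpacking wins, so call-time kwargs override prompt kwargs.
    kwargs = {**prompt_kwargs, **call_kwargs}
    assert kwargs == {"temperature": 0.9, "stop": ["###"]}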

View File

@@ -299,6 +299,8 @@ class ChatPromptValue(PromptValue):
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""
 
+    llm_kwargs: dict = Field(default_factory=dict)
+
     @property
     def lc_attributes(self) -> Dict:
         """
@@ -330,7 +332,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
             PromptValue.
         """
         messages = self.format_messages(**kwargs)
-        return ChatPromptValue(messages=messages)
+        return ChatPromptValue(messages=messages, llm_kwargs=self.llm_kwargs)
 
     @abstractmethod
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
@@ -487,8 +489,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
     @classmethod
     def from_messages(
-        cls,
-        messages: Sequence[MessageLikeRepresentation],
+        cls, messages: Sequence[MessageLikeRepresentation], **kwargs: Any
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.
@@ -534,7 +535,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
             ):
                 input_vars.update(_message.input_variables)
-        return cls(input_variables=sorted(input_vars), messages=_messages)
+        return cls(input_variables=sorted(input_vars), messages=_messages, **kwargs)
 
     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.

View File

@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 from typing import List
 
 from langchain.load.serializable import Serializable
+from langchain.pydantic_v1 import Field
 from langchain.schema.messages import BaseMessage
@@ -14,6 +15,8 @@ class PromptValue(Serializable, ABC):
     ChatModel inputs.
     """
 
+    llm_kwargs: dict = Field(default_factory=dict)
+
     @property
     def lc_serializable(self) -> bool:
         """

File diff suppressed because one or more lines are too long