mirror of https://github.com/hwchase17/langchain.git

Commit message: stash
@@ -212,6 +212,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
                 messages,
                 functions=self.functions,
                 callbacks=callbacks,
+                **prompt.llm_kwargs,
             )
         else:
             predicted_message = self.llm.predict_messages(
@@ -245,7 +246,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
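Net effect of the two hunks above: kwargs attached to the formatted prompt (for example a forced function_call) now reach the model call instead of being dropped. A minimal sketch of the flow, assuming an agent whose prompt template was built with llm_kwargs set; the input values and kwarg contents are placeholders, not taken from the commit:

    prompt_value = agent.prompt.format_prompt(input="What is 2 + 2?", agent_scratchpad=[])
    predicted_message = llm.predict_messages(
        prompt_value.to_messages(),
        functions=agent.functions,
        # e.g. {"function_call": {"name": "calculator"}} (assumed value)
        **prompt_value.llm_kwargs,
    )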
@@ -267,7 +267,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
@@ -296,7 +296,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         prompt = self.prompt.format_prompt(**full_inputs)
         messages = prompt.to_messages()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **prompt.llm_kwargs
         )
         agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
@@ -160,7 +160,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # kwargs from prompt defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = CallbackManager.configure(
@@ -207,7 +210,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = config or {}
-            messages = self._convert_input(input).to_messages()
+            prompt = self._convert_input(input)
+            messages = prompt.to_messages()
+            # prompt kwargs defer to kwargs passed in
+            kwargs = {**prompt.llm_kwargs, **kwargs}
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             callback_manager = AsyncCallbackManager.configure(
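Both stream paths merge the two kwarg sources with plain dict unpacking, so a kwarg passed at call time wins over the same key stored on the prompt, matching the "defer to kwargs passed in" comment. A toy illustration with made-up values:

    prompt_kwargs = {"temperature": 0.0, "function_call": "auto"}
    call_kwargs = {"temperature": 0.7}
    # later keys win in dict unpacking, so the caller's temperature overrides the prompt's
    merged = {**prompt_kwargs, **call_kwargs}
    assert merged == {"temperature": 0.7, "function_call": "auto"}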
@@ -410,6 +416,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
 
@@ -420,6 +431,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_messages = [p.to_messages() for p in prompts]
         return await self.agenerate(
             prompt_messages, stop=stop, callbacks=callbacks, **kwargs
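generate_prompt and agenerate_prompt collapse a batch of prompt values into one model call, so per-prompt kwargs cannot diverge; the guard above rejects mixed batches up front. A hypothetical failing batch (constructed by hand only to show the check):

    from langchain.prompts.chat import ChatPromptValue

    prompts = [
        ChatPromptValue(messages=[], llm_kwargs={"function_call": "auto"}),
        ChatPromptValue(messages=[], llm_kwargs={}),
    ]
    # chat_model.generate_prompt(prompts) would raise:
    # ValueError: All prompt kwargs must be the same when calling in batch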
@@ -334,7 +334,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -381,7 +383,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             # model doesn't implement streaming, so use default implementation
             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
         else:
-            prompt = self._convert_input(input).to_string()
+            prompt_value = self._convert_input(input)
+            prompt = prompt_value.to_string()
+            kwargs = {**prompt_value.llm_kwargs, **kwargs}
             config = config or {}
             params = self.dict()
             params["stop"] = stop
@@ -463,6 +467,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
 
@@ -473,6 +482,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        llm_kwargs = prompts[0].llm_kwargs
+        for prompt in prompts:
+            if prompt.llm_kwargs != llm_kwargs:
+                raise ValueError("All prompt kwargs must be the same when calling in batch")
+        kwargs = {**llm_kwargs, **kwargs}
         prompt_strings = [p.to_string() for p in prompts]
         return await self.agenerate(
             prompt_strings, stop=stop, callbacks=callbacks, **kwargs
@@ -298,6 +298,7 @@ class ChatPromptValue(PromptValue):
 
 class BaseChatPromptTemplate(BasePromptTemplate, ABC):
     """Base class for chat prompt templates."""
+    llm_kwargs: dict = Field(default_factory=dict)
 
     @property
     def lc_attributes(self) -> Dict:
@@ -330,7 +331,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
             PromptValue.
         """
         messages = self.format_messages(**kwargs)
-        return ChatPromptValue(messages=messages)
+        return ChatPromptValue(messages=messages, llm_kwargs=self.llm_kwargs)
 
     @abstractmethod
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
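With format_prompt forwarding the field, a template's llm_kwargs ride along on every ChatPromptValue it produces. A minimal sketch, assuming a template constructed with llm_kwargs via the from_messages passthrough shown in the next two hunks; the message tuple and kwarg values are illustrative:

    from langchain.prompts.chat import ChatPromptTemplate

    template = ChatPromptTemplate.from_messages(
        [("human", "Translate {text} to French.")],
        llm_kwargs={"temperature": 0.0},
    )
    prompt_value = template.format_prompt(text="hello")
    assert prompt_value.llm_kwargs == {"temperature": 0.0}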
@@ -489,6 +490,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
     def from_messages(
         cls,
         messages: Sequence[MessageLikeRepresentation],
+        **kwargs: Any
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.
 
@@ -534,7 +536,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
         ):
             input_vars.update(_message.input_variables)
 
-        return cls(input_variables=sorted(input_vars), messages=_messages)
+        return cls(input_variables=sorted(input_vars), messages=_messages, **kwargs)
 
     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.
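Because from_messages now forwards **kwargs into the constructor, passing llm_kwargs through the classmethod (as in the sketch above) should be equivalent to setting the field directly on the constructor; the kwarg value here is illustrative:

    from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

    template = ChatPromptTemplate(
        input_variables=["text"],
        messages=[HumanMessagePromptTemplate.from_template("{text}")],
        llm_kwargs={"max_tokens": 32},  # assumed example kwarg
    )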
@@ -3,6 +3,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import List
 
+from langchain.pydantic_v1 import Field
+
 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage
 
@@ -13,6 +15,7 @@ class PromptValue(Serializable, ABC):
     PromptValues can be converted to both LLM (pure text-generation) inputs and
     ChatModel inputs.
     """
+    llm_kwargs: dict = Field(default_factory=dict)
 
     @property
     def lc_serializable(self) -> bool:
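The field sits on the PromptValue base class with an empty-dict default, so existing prompt values are unaffected; only prompts that explicitly carry kwargs change what reaches the model. A quick check of the default, assuming the concrete StringPromptValue subclass from langchain.prompts.base:

    from langchain.prompts.base import StringPromptValue

    value = StringPromptValue(text="hi")
    assert value.llm_kwargs == {}  # Field(default_factory=dict)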