Compare commits

...

1 commit

Author           SHA1         Message                      Date
Harrison Chase   ee8203859c   add kwargs for chat models   2023-06-08 00:52:38 -07:00
7 changed files with 96 additions and 34 deletions

File 1 of 7

@@ -2,7 +2,7 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set
 from pydantic import BaseModel
@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
     @abstractmethod
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
     @abstractmethod
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""
     @abstractmethod
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""
     @abstractmethod
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""

File 2 of 7

@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop

File 3 of 7

@@ -64,11 +64,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
         params["stop"] = stop
+        params.update(kwargs)
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -82,9 +84,9 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else self._generate(m, stop=stop)
+                else self._generate(m, stop=stop, **kwargs)
                 for m in messages
             ]
         except (KeyboardInterrupt, Exception) as e:
@@ -103,10 +105,12 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
         params["stop"] = stop
+        params.update(kwargs)
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -121,9 +125,9 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                     if new_arg_supported
-                    else self._agenerate(m, stop=stop)
+                    else self._agenerate(m, stop=stop, **kwargs)
                     for m in messages
                 ]
             )
@@ -143,18 +147,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+        )
     @abstractmethod
     def _generate(
@@ -162,6 +170,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Top Level call"""
@@ -171,6 +180,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Top Level call"""
@@ -179,9 +189,10 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> BaseMessage:
         generation = self.generate(
-            [messages], stop=stop, callbacks=callbacks
+            [messages], stop=stop, callbacks=callbacks, **kwargs
         ).generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
@@ -193,50 +204,69 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any
     ) -> BaseMessage:
-        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [messages], stop=stop, callbacks=callbacks, **kwargs
+        )
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
         else:
             raise ValueError("Unexpected generation type")
-    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        return self.predict(message, stop=stop)
+    def call_as_llm(
+        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        return self.predict(message, stop=stop, **kwargs)
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = self([HumanMessage(content=text)], stop=_stop)
+        result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         return result.content
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any
    ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(messages, stop=_stop)
+        return self(messages, stop=_stop, **kwargs)
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        result = await self._call_async(
+            [HumanMessage(content=text)], stop=_stop, **kwargs
+        )
         return result.content
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any
    ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(messages, stop=_stop)
+        return await self._call_async(messages, stop=_stop, **kwargs)
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -261,8 +291,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
@@ -273,6 +304,7 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> str:
         """Simpler interface."""
@@ -281,6 +313,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
-        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
         return await asyncio.get_event_loop().run_in_executor(None, func)
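
To see where the forwarded kwargs finally surface, here is a toy SimpleChatModel subclass; EchoChatModel is hypothetical test code written against this commit, and only the plumbing it exercises (predict to generate to _generate to _call) comes from the diff above.

# Hypothetical sketch; EchoChatModel is not part of the commit.
from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class EchoChatModel(SimpleChatModel):
    """Echoes the last message plus whatever kwargs reached _call."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Keywords given to predict()/generate() arrive here untouched.
        return f"{messages[-1].content} | kwargs={kwargs}"


model = EchoChatModel()
print(model.predict("hi", foo="bar"))  # hi | kwargs={'foo': 'bar'}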

File 4 of 7

@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
         )
         return _response_to_result(response, stop)
@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
@@ -311,6 +314,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
        )
         return _response_to_result(response, stop)
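
One sharp edge worth noting: the Anthropic wrapper merges kwargs into a dict literal, where the call-time key quietly wins, but this wrapper splats **kwargs beside explicitly wired keywords, so re-supplying an already-wired name such as top_k raises TypeError instead of overriding it. A self-contained reproduction of the two merge styles, plain Python with no LangChain imports:

defaults = {"top_k": 40}

# Dict-merge style (Anthropic pattern): later keys win, no error.
merged = {**defaults, "top_k": 1}
assert merged["top_k"] == 1

# Keyword splat beside an explicit keyword (PaLM pattern): collides.
def chat(top_k: int = 40, **kwargs: object) -> None:
    pass

try:
    chat(top_k=defaults["top_k"], **{"top_k": 1})
except TypeError as err:
    print(err)  # chat() got multiple values for keyword argument 'top_k'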

File 5 of 7

@@ -302,8 +302,9 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -324,13 +325,14 @@ class ChatOpenAI(BaseChatModel):
         return self._create_chat_result(response)
     def _create_message_dicts(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
+        self, messages: List[BaseMessage], stop: Optional[List[str]], **kwargs: Any
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params = dict(self._invocation_params)
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
+        params.update(kwargs)
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
@@ -348,8 +350,9 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages, stop, **kwargs)
         if self.streaming:
             inner_completion = ""
             role = "assistant"

File 6 of 7

@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -52,7 +53,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
             response_dict, params = super()._create_message_dicts(
-                [generation.message], stop
+                [generation.message], stop, **kwargs
             )
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerChatOpenAI",
@@ -79,6 +80,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
@@ -89,7 +91,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         message_dicts, params = super()._create_message_dicts(messages, stop)
         for i, generation in enumerate(generated_responses.generations):
             response_dict, params = super()._create_message_dicts(
-                [generation.message], stop
+                [generation.message], stop, **kwargs
             )
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerChatOpenAI.async",

File 7 of 7

@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from pydantic import root_validator
@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -122,7 +123,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self.client.start_chat(context=context, **self._default_params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
-        response = chat.send_message(question.content, **self._default_params)
+        response = chat.send_message(question.content, **self._default_params, **kwargs)
         text = self._enforce_stop_words(response.text, stop)
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
@@ -131,6 +132,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         raise NotImplementedError(
             """Vertex AI doesn't support async requests at the moment."""