Compare commits

...

1 Commit

Author   SHA1         Message   Date
Bagatur  f0a7fab51b   rfc       2023-10-24 20:39:45 -07:00
24 changed files with 376 additions and 292 deletions


@@ -160,30 +160,31 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
) -> List[ChatResult]:
if self.streaming and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
prompt = self._convert_messages_to_prompt(
messages,
)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
return [_generate_from_stream(stream_iter)]
results = []
for msgs_prompt in messages:
prompt = self._convert_messages_to_prompt(msgs_prompt)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
results.append(ChatResult(generations=[ChatGeneration(message=message)]))
return results
async def _agenerate(
self,

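The change above sets the pattern repeated across every chat model in this commit: _generate now receives one list of messages per prompt, takes the streaming path only when the batch holds a single prompt, and otherwise loops over the batch and returns one ChatResult per prompt. A minimal standalone sketch of that contract, assuming the langchain.schema import path used on this branch; EchoChatModel is illustrative and not part of the commit:

# Illustrative sketch of the batched _generate contract proposed in this commit.
# EchoChatModel is hypothetical; only the method shape mirrors the diff above.
from typing import Any, List, Optional

from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


class EchoChatModel:
    def _generate(
        self,
        messages: List[List[BaseMessage]],  # one inner list of messages per prompt
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[ChatResult]:
        results = []
        for msgs_prompt in messages:  # one ChatResult per prompt, order preserved
            reply = AIMessage(content=msgs_prompt[-1].content)
            results.append(ChatResult(generations=[ChatGeneration(message=reply)]))
        return results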

@@ -195,25 +195,26 @@ class ChatBaichuan(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
) -> List[ChatResult]:
if self.streaming and len(messages) == 1:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
res = self._chat(messages, **kwargs)
results = []
for msgs_prompt in messages:
res = self._chat(msgs_prompt, **kwargs)
response = res.json()
if response.get("code") != 0:
raise ValueError(f"Error from Baichuan api response: {response}")
response = res.json()
if response.get("code") != 0:
raise ValueError(f"Error from Baichuan api response: {response}")
return self._create_chat_result(response)
results.append(self._create_chat_result(response))
return results
def _stream(
self,


@@ -221,11 +221,11 @@ class QianfanChatEndpoint(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
messages: The messages to pass into the model.
@@ -237,32 +237,37 @@ class QianfanChatEndpoint(BaseChatModel):
.. code-block:: python
response = qianfan_model("Tell me a joke.")
"""
if self.streaming:
if self.streaming and len(messages) == 1:
completion = ""
for chunk in self._stream(messages, stop, run_manager, **kwargs):
for chunk in self._stream(messages[0], stop, run_manager, **kwargs):
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
message=lc_msg,
generation_info=dict(finish_reason="stop"),
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
return [
ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
)
]
results = []
for msgs_prompt in messages:
params = self._convert_prompt_msg_params(msgs_prompt, **kwargs)
response_payload = self.client.do(**params)
lc_msg = _convert_dict_to_message(response_payload)
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = self.client.do(**params)
lc_msg = _convert_dict_to_message(response_payload)
gen = ChatGeneration(
message=lc_msg,
generation_info={
"finish_reason": "stop",
**response_payload.get("body", {}),
},
)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=[gen], llm_output=llm_output)
token_usage = response_payload.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
results.append(ChatResult(generations=[gen], llm_output=llm_output))
return results
async def _agenerate(
self,

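QianfanChatEndpoint's single-prompt streaming branch above accumulates chunk text into one completion string instead of calling _generate_from_stream. A hedged sketch of that accumulation step in isolation; result_from_chunks is an illustrative helper, not part of the commit:

# Sketch of the chunk-accumulation pattern used in the streaming branch above;
# `chunks` stands for the generation chunks yielded by self._stream.
from typing import Any, Iterable

from langchain.schema import AIMessage, ChatGeneration, ChatResult


def result_from_chunks(chunks: Iterable[Any], model_name: str) -> ChatResult:
    completion = "".join(chunk.text for chunk in chunks)  # concatenate streamed text
    gen = ChatGeneration(
        message=AIMessage(content=completion),
        generation_info=dict(finish_reason="stop"),
    )
    return ChatResult(
        generations=[gen],
        llm_output={"token_usage": {}, "model_name": model_name},
    )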

@@ -472,11 +472,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
def _generate_with_cache(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
@@ -496,19 +496,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
return self._generate(messages, stop=stop, **kwargs)
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
results = []
to_generate = []
for i, msgs_prompt in enumerate(messages):
prompt = dumps(msgs_prompt)
cache_val = llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
results.append((i, ChatResult(generations=cache_val)))
else:
to_generate.append((i, msgs_prompt))
if to_generate:
generate_idxs, generate_messages = tuple(zip(*to_generate))
if new_arg_supported:
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
generations = self._generate(
list(generate_messages),
stop=stop,
run_manager=run_manager,
**kwargs,
)
else:
result = self._generate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
return result
generations = self._generate(
list(generate_messages), stop=stop, **kwargs
)
for msgs_prompt, chat_res in zip(generate_messages, generations):
prompt = dumps(msgs_prompt)
llm_cache.update(prompt, llm_string, chat_res.generations)
results.extend(zip(generate_idxs, generations))
return [res for _, res in sorted(results)]
async def _agenerate_with_cache(
self,
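The _generate_with_cache change above is the core of the RFC: look up each prompt in the cache, send only the misses through a single batched _generate call, write the new generations back, and reassemble the results in the caller's order by index. A standalone sketch of that partition-and-merge flow, with cache and generate_batch standing in for llm_cache and self._generate:

# Standalone sketch of the cache partition/merge flow in _generate_with_cache above.
# `cache` and `generate_batch` are stand-ins for llm_cache and self._generate.
from typing import Callable, Dict, List, Tuple


def generate_with_cache(
    prompts: List[str],
    cache: Dict[str, str],
    generate_batch: Callable[[List[str]], List[str]],
) -> List[str]:
    results: List[Tuple[int, str]] = []
    to_generate: List[Tuple[int, str]] = []
    for i, prompt in enumerate(prompts):
        if prompt in cache:                      # cache hit: reuse the stored value
            results.append((i, cache[prompt]))
        else:                                    # cache miss: defer to one batched call
            to_generate.append((i, prompt))
    if to_generate:
        idxs, misses = zip(*to_generate)
        generated = generate_batch(list(misses))
        for prompt, out in zip(misses, generated):
            cache[prompt] = out                  # write back so later calls hit
        results.extend(zip(idxs, generated))
    return [out for _, out in sorted(results)]   # restore the original prompt order

Sorting by the stored index is what lets cached and freshly generated results interleave without losing the caller's prompt order.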
@@ -553,20 +567,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@abstractmethod
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Top Level call"""
async def _agenerate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Top Level call"""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
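The default _agenerate above keeps the same batched signature and simply runs the synchronous _generate in an executor, so a single await yields the whole list of results. A simplified, self-contained sketch of that fallback using plain functions in place of the class methods:

# Simplified sketch of the default _agenerate fallback shown above.
import asyncio
from functools import partial
from typing import List, Optional


def _generate(prompts: List[str], stop: Optional[List[str]] = None) -> List[str]:
    return [f"result for {p}" for p in prompts]   # stand-in for the sync batched call


async def _agenerate(prompts: List[str], stop: Optional[List[str]] = None) -> List[str]:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(_generate, stop=stop), prompts)


print(asyncio.run(_agenerate(["a", "b"])))        # ['result for a', 'result for b']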
@@ -696,15 +710,20 @@ class SimpleChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
) -> List[ChatResult]:
results = []
for msgs_prompt in messages:
output_str = self._call(
msgs_prompt, stop=stop, run_manager=run_manager, **kwargs
)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
results.append(ChatResult(generations=[generation]))
return results
@abstractmethod
def _call(

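For SimpleChatModel subclasses nothing changes in subclass code: _call still handles a single conversation, and the batched _generate above supplies the loop over the batch. A hypothetical subclass under this proposed signature; CannedChatModel and its reply text are illustrative only:

# Hypothetical SimpleChatModel subclass under the proposed signature; _call keeps
# receiving one conversation while the batched _generate above loops over the batch.
from typing import Any, List, Optional

from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage


class CannedChatModel(SimpleChatModel):
    @property
    def _llm_type(self) -> str:
        return "canned-chat"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        return "canned reply to: " + str(messages[-1].content)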

@@ -66,21 +66,22 @@ class BedrockChat(BaseChatModel, BedrockBase):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
completion = ""
if self.streaming:
for chunk in self._stream(messages, stop, run_manager, **kwargs):
if self.streaming and len(messages) == 1:
for chunk in self._stream(messages[0], stop, run_manager, **kwargs):
completion += chunk.text
else:
provider = self._get_provider()
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages
)
message = AIMessage(content=completion)
return [ChatResult(generations=[ChatGeneration(message=message)])]
provider = self._get_provider()
results = []
for msgs_prompt in messages:
prompt = ChatPromptAdapter.convert_messages_to_prompt(provider, msgs_prompt)
params: Dict[str, Any] = {**kwargs}
if stop:
@@ -89,9 +90,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
completion = self._prepare_input_and_invoke(
prompt=prompt, stop=stop, run_manager=run_manager, **params
)
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
message = AIMessage(content=completion)
results.append(ChatResult(generations=[ChatGeneration(message=message)]))
return results
def get_num_tokens(self, text: str) -> int:
if self._model_is_anthropic:


@@ -168,29 +168,36 @@ class ChatCohere(BaseChatModel, BaseCohere):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
) -> List[ChatResult]:
if self.streaming and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request)
results = []
for msgs_prompt in messages:
request = get_cohere_chat_request(
msgs_prompt, **self._default_params, **kwargs
)
response = self.client.chat(**request)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
message = AIMessage(content=response.text)
generation_info = None
if hasattr(response, "documents"):
generation_info = {"documents": response.documents}
results.append(
ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
]
)
)
return results
async def _agenerate(
self,


@@ -150,33 +150,37 @@ class ErnieBotChat(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
if self.streaming:
raise ValueError("`streaming` option currently unsupported.")
if not self.access_token:
self._refresh_access_token_with_lock()
payload = {
"messages": [_convert_message_to_dict(m) for m in messages],
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
**kwargs,
}
logger.debug(f"Payload for ernie api is {payload}")
resp = self._chat(payload)
if resp.get("error_code"):
if resp.get("error_code") == 111:
logger.debug("access_token expired, refresh it")
self._refresh_access_token_with_lock()
resp = self._chat(payload)
else:
raise ValueError(f"Error from ErnieChat api response: {resp}")
return self._create_chat_result(resp)
results = []
for msgs_prompt in messages:
payload = {
"messages": [_convert_message_to_dict(m) for m in msgs_prompt],
"top_p": self.top_p,
"temperature": self.temperature,
"penalty_score": self.penalty_score,
**kwargs,
}
logger.debug(f"Payload for ernie api is {payload}")
resp = self._chat(payload)
if resp.get("error_code"):
if resp.get("error_code") == 111:
logger.debug("access_token expired, refresh it")
self._refresh_access_token_with_lock()
resp = self._chat(payload)
else:
raise ValueError(f"Error from ErnieChat api response: {resp}")
results.append(self._create_chat_result(resp))
return results
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = [


@@ -22,18 +22,21 @@ class FakeMessagesListChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])
) -> List[ChatResult]:
results = []
for _ in messages:
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
results.append(ChatResult(generations=[generation]))
return results
@property
def _llm_type(self) -> str:

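Under the batched API, the fake model above hands out one canned response per prompt in the batch, wrapping around when the list runs out. A hypothetical usage sketch that only works against this RFC branch (it calls the private _generate directly for illustration):

# Hypothetical usage of FakeMessagesListChatModel against this RFC branch only.
from langchain.chat_models.fake import FakeMessagesListChatModel
from langchain.schema import AIMessage, HumanMessage

fake = FakeMessagesListChatModel(
    responses=[AIMessage(content="a"), AIMessage(content="b")]
)
batch = [[HumanMessage(content=str(i))] for i in range(3)]
results = fake._generate(batch)                # three prompts -> three ChatResults
print([r.generations[0].message.content for r in results])   # ['a', 'b', 'a']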

@@ -113,22 +113,25 @@ class ChatFireworks(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = self._create_message_dicts(messages)
) -> List[ChatResult]:
results = []
for msgs_prompt in messages:
message_dicts = self._create_message_dicts(msgs_prompt)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
return self._create_chat_result(response)
params = {
"model": self.model,
"messages": message_dicts,
**self.model_kwargs,
}
response = completion_with_retry(
self, run_manager=run_manager, stop=stop, **params
)
results.append(self._create_chat_result(response))
return results
async def _agenerate(
self,


@@ -104,23 +104,26 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if should_stream and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
payload = self._build_payload(messages)
response = self._client.chat(payload)
results = []
for msgs_prompt in messages:
payload = self._build_payload(msgs_prompt)
response = self._client.chat(payload)
return self._create_chat_result(response)
results.append(self._create_chat_result(response))
return results
async def _agenerate(
self,


@@ -287,25 +287,28 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
) -> List[ChatResult]:
results = []
for msgs_prompt in messages:
prompt = _messages_to_prompt_dict(msgs_prompt)
response: genai.types.ChatResponse = chat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
response: genai.types.ChatResponse = chat_with_retry(
self,
model=self.model_name,
prompt=prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
return _response_to_result(response, stop)
results.append(_response_to_result(response, stop))
return results
async def _agenerate(
self,


@@ -89,11 +89,11 @@ class HumanInputChatModel(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""
Displays the messages to the user and returns their input as a response.
@@ -105,9 +105,12 @@ class HumanInputChatModel(BaseChatModel):
Returns:
ChatResult: The user's input as a response.
"""
self.message_func(messages, **self.message_kwargs)
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
return ChatResult(generations=[ChatGeneration(message=user_input)])
results = []
for msgs_prompt in messages:
self.message_func(msgs_prompt, **self.message_kwargs)
user_input = self.input_func(msgs_prompt, stop=stop, **self.input_kwargs)
results.append(ChatResult(generations=[ChatGeneration(message=user_input)]))
return results
async def _agenerate(
self,


@@ -240,25 +240,28 @@ class ChatHunyuan(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
) -> List[ChatResult]:
if self.streaming and len(messages) == 1:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
res = self._chat(messages, **kwargs)
results = []
for msgs_prompt in messages:
res = self._chat(msgs_prompt, **kwargs)
response = res.json()
response = res.json()
if "error" in response:
raise ValueError(f"Error from Hunyuan api response: {response}")
if "error" in response:
raise ValueError(f"Error from Hunyuan api response: {response}")
return _create_chat_result(response)
results.append(_create_chat_result(response))
return results
def _stream(
self,


@@ -101,23 +101,26 @@ class ChatJavelinAIGateway(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
) -> List[ChatResult]:
results = []
for msgs_prompt in messages:
message_dicts = [
ChatJavelinAIGateway._convert_message_to_dict(message)
for message in msgs_prompt
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = self.client.query_route(self.route, query_body=data)
resp = self.client.query_route(self.route, query_body=data)
return ChatJavelinAIGateway._create_chat_result(resp.dict())
results.append(ChatJavelinAIGateway._create_chat_result(resp.dict()))
return results
async def _agenerate(
self,


@@ -319,21 +319,24 @@ class JinaChat(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
) -> List[ChatResult]:
if self.streaming and len(messages) == 1:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
results = []
for msgs_prompt in messages:
message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
results.append(self._create_chat_result(response))
return results
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]


@@ -289,25 +289,27 @@ class ChatLiteLLM(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if should_stream and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
return [_generate_from_stream(stream_iter)]
results = []
for msgs_prompt in messages:
message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
results.append(self._create_chat_result(response))
return results
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []


@@ -51,11 +51,11 @@ class MiniMaxChat(MinimaxCommon, BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages. Code chat
@@ -73,13 +73,16 @@ class MiniMaxChat(MinimaxCommon, BaseChatModel):
raise ValueError(
"You should provide at least one message to start the chat!"
)
history = _parse_chat_history(messages)
payload = self._default_params
payload["messages"] = history
text = self._client.post(payload)
results = []
for msgs_prompt in messages:
history = _parse_chat_history(msgs_prompt)
payload = self._default_params
payload["messages"] = history
text = self._client.post(payload)
# This is required since the stop are not enforced by the model parameters
return text if stop is None else enforce_stop_tokens(text, stop)
# This is required since the stop are not enforced by the model parameters
results.append(text if stop is None else enforce_stop_tokens(text, stop))
return results
async def _agenerate(
self,


@@ -85,11 +85,11 @@ class ChatMLflowAIGateway(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
try:
import mlflow.gateway
except ImportError as e:
@@ -98,17 +98,20 @@ class ChatMLflowAIGateway(BaseChatModel):
"Please install it with `pip install mlflow[gateway]`."
) from e
message_dicts = [
ChatMLflowAIGateway._convert_message_to_dict(message)
for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
results = []
for msgs_prompt in messages:
message_dicts = [
ChatMLflowAIGateway._convert_message_to_dict(message)
for message in msgs_prompt
]
data: Dict[str, Any] = {
"messages": message_dicts,
**(self.params.dict() if self.params else {}),
}
resp = mlflow.gateway.query(self.route, data=data)
return ChatMLflowAIGateway._create_chat_result(resp)
resp = mlflow.gateway.query(self.route, data=data)
results.append(ChatMLflowAIGateway._create_chat_result(resp))
return results
async def _agenerate(
self,


@@ -72,11 +72,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Call out to Ollama's generate endpoint.
Args:


@@ -343,24 +343,27 @@ class ChatOpenAI(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if should_stream and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
return [_generate_from_stream(stream_iter)]
results = []
for msgs_prompt in messages:
message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
results.append(self._create_chat_result(response))
return results
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]

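One caller-visible consequence of the len(messages) == 1 guard above: with streaming=True, token callbacks only fire when the batch contains a single prompt; larger batches fall back to one blocking request per prompt. A hypothetical usage sketch against this RFC branch (assumes OPENAI_API_KEY is set):

# Hypothetical usage against this RFC branch; assumes OPENAI_API_KEY is set.
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

llm = ChatOpenAI(streaming=True)

# Single-prompt batch: takes the streaming path and returns [ChatResult].
llm._generate([[HumanMessage(content="hi")]])

# Two-prompt batch: skips streaming and makes one blocking request per prompt.
llm._generate([[HumanMessage(content="hi")], [HumanMessage(content="bye")]])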

@@ -182,15 +182,20 @@ class PaiEasChatEndpoint(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
) -> List[ChatResult]:
results = []
for msgs_prompt in messages:
output_str = self._call(
msgs_prompt, stop=stop, run_manager=run_manager, **kwargs
)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
results.append(ChatResult(generations=[generation]))
return results
def _call(
self,


@@ -306,32 +306,35 @@ class ChatTongyi(BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
should_stream = stream if stream is not None else self.streaming
if should_stream:
if should_stream and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
if not messages:
raise ValueError("No messages provided.")
message_dicts, params = self._create_message_dicts(messages, stop)
results = []
for msgs_prompt in messages:
message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
if message_dicts[-1]["role"] != "user":
raise ValueError("Last message should be user message.")
if message_dicts[-1]["role"] != "user":
raise ValueError("Last message should be user message.")
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
return self._create_chat_result(response)
params = {**params, **kwargs}
response = self.completion_with_retry(
messages=message_dicts, run_manager=run_manager, **params
)
results.append(self._create_chat_result(response))
return results
def _stream(
self,


@@ -145,12 +145,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Generate next turn in the conversation.
Args:
@@ -167,11 +167,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
ValueError: if the last message in the list is not from human.
"""
should_stream = stream if stream is not None else self.streaming
if should_stream:
if should_stream and len(messages) == 1:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
messages[0], stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)
return [_generate_from_stream(stream_iter)]
question = _get_question(messages)
history = _parse_chat_history(messages[:-1])


@@ -63,11 +63,11 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
def _generate(
self,
messages: List[BaseMessage],
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> List[ChatResult]:
"""Generate next turn in the conversation.
Args:
messages: The history of the conversation as a list of messages.
@@ -96,28 +96,31 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
raise ValueError(
"You should provide at least one message to start the chat!"
)
message_history, instruction = _parse_chat_history(messages)
channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(self.url, channel_credentials)
request = ChatRequest(
model=self.model_name,
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens),
),
instruction_text=instruction,
messages=[Message(**message) for message in message_history],
)
stub = TextGenerationServiceStub(channel)
if self.iam_token:
metadata = (("authorization", f"Bearer {self.iam_token}"),)
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
res = stub.Chat(request, metadata=metadata)
text = list(res)[0].message.text
text = text if stop is None else enforce_stop_tokens(text, stop)
message = AIMessage(content=text)
return ChatResult(generations=[ChatGeneration(message=message)])
results = []
for msgs_prompt in messages:
message_history, instruction = _parse_chat_history(msgs_prompt)
channel_credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(self.url, channel_credentials)
request = ChatRequest(
model=self.model_name,
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens),
),
instruction_text=instruction,
messages=[Message(**message) for message in message_history],
)
stub = TextGenerationServiceStub(channel)
if self.iam_token:
metadata = (("authorization", f"Bearer {self.iam_token}"),)
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
res = stub.Chat(request, metadata=metadata)
text = list(res)[0].message.text
text = text if stop is None else enforce_stop_tokens(text, stop)
message = AIMessage(content=text)
results.append(ChatResult(generations=[ChatGeneration(message=message)]))
return results
async def _agenerate(
self,