mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-13 14:21:27 +00:00

Compare commits: langchain-...bagatur/rf (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | f0a7fab51b | |
```diff
@@ -160,30 +160,31 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        if self.streaming:
+    ) -> List[ChatResult]:
+        if self.streaming and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
-        prompt = self._convert_messages_to_prompt(
-            messages,
-        )
-        params: Dict[str, Any] = {
-            "prompt": prompt,
-            **self._default_params,
-            **kwargs,
-        }
-        if stop:
-            params["stop_sequences"] = stop
-        response = self.client.completions.create(**params)
-        completion = response.completion
-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+            return [_generate_from_stream(stream_iter)]
+        results = []
+        for msgs_prompt in messages:
+            prompt = self._convert_messages_to_prompt(msgs_prompt)
+            params: Dict[str, Any] = {
+                "prompt": prompt,
+                **self._default_params,
+                **kwargs,
+            }
+            if stop:
+                params["stop_sequences"] = stop
+            response = self.client.completions.create(**params)
+            completion = response.completion
+            message = AIMessage(content=completion)
+            results.append(ChatResult(generations=[ChatGeneration(message=message)]))
+        return results

     async def _agenerate(
         self,
```
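The same rewrite is applied to every provider below: `_generate` now takes a batch of prompts (`List[List[BaseMessage]]`) and returns one `ChatResult` per prompt. A minimal sketch of that contract, outside any provider; `_call_provider` is a placeholder for the provider-specific request and is not part of the diff, and the imports assume this era's `langchain.schema` layout:

```python
from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


def _generate_batched(
    self: Any,
    messages: List[List[BaseMessage]],  # one prompt == one list of messages
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> List[ChatResult]:
    """Shape of the refactored _generate: one ChatResult per input prompt."""
    results: List[ChatResult] = []
    for msgs_prompt in messages:
        # `_call_provider` is a hypothetical stand-in for the provider-specific call
        # each subclass builds from a single prompt (Anthropic, Baichuan, ...).
        completion: str = self._call_provider(msgs_prompt, stop=stop, **kwargs)
        message = AIMessage(content=completion)
        results.append(ChatResult(generations=[ChatGeneration(message=message)]))
    return results
```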
```diff
@@ -195,25 +195,26 @@ class ChatBaichuan(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        if self.streaming:
+    ) -> List[ChatResult]:
+        if self.streaming and len(messages) == 1:
             stream_iter = self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

-        res = self._chat(messages, **kwargs)
-
-        response = res.json()
-
-        if response.get("code") != 0:
-            raise ValueError(f"Error from Baichuan api response: {response}")
-
-        return self._create_chat_result(response)
+        results = []
+        for msgs_prompt in messages:
+            res = self._chat(msgs_prompt, **kwargs)
+            response = res.json()
+            if response.get("code") != 0:
+                raise ValueError(f"Error from Baichuan api response: {response}")
+            results.append(self._create_chat_result(response))
+        return results

     def _stream(
         self,
```
```diff
@@ -221,11 +221,11 @@ class QianfanChatEndpoint(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Call out to an qianfan models endpoint for each generation with a prompt.
         Args:
             messages: The messages to pass into the model.
@@ -237,32 +237,37 @@ class QianfanChatEndpoint(BaseChatModel):
         .. code-block:: python
             response = qianfan_model("Tell me a joke.")
         """
-        if self.streaming:
+        if self.streaming and len(messages) == 1:
             completion = ""
-            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+            for chunk in self._stream(messages[0], stop, run_manager, **kwargs):
                 completion += chunk.text
             lc_msg = AIMessage(content=completion, additional_kwargs={})
             gen = ChatGeneration(
                 message=lc_msg,
                 generation_info=dict(finish_reason="stop"),
             )
-            return ChatResult(
-                generations=[gen],
-                llm_output={"token_usage": {}, "model_name": self.model},
-            )
-        params = self._convert_prompt_msg_params(messages, **kwargs)
-        response_payload = self.client.do(**params)
-        lc_msg = _convert_dict_to_message(response_payload)
-        gen = ChatGeneration(
-            message=lc_msg,
-            generation_info={
-                "finish_reason": "stop",
-                **response_payload.get("body", {}),
-            },
-        )
-        token_usage = response_payload.get("usage", {})
-        llm_output = {"token_usage": token_usage, "model_name": self.model}
-        return ChatResult(generations=[gen], llm_output=llm_output)
+            return [
+                ChatResult(
+                    generations=[gen],
+                    llm_output={"token_usage": {}, "model_name": self.model},
+                )
+            ]
+        results = []
+        for msgs_prompt in messages:
+            params = self._convert_prompt_msg_params(msgs_prompt, **kwargs)
+            response_payload = self.client.do(**params)
+            lc_msg = _convert_dict_to_message(response_payload)
+            gen = ChatGeneration(
+                message=lc_msg,
+                generation_info={
+                    "finish_reason": "stop",
+                    **response_payload.get("body", {}),
+                },
+            )
+            token_usage = response_payload.get("usage", {})
+            llm_output = {"token_usage": token_usage, "model_name": self.model}
+            results.append(ChatResult(generations=[gen], llm_output=llm_output))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -472,11 +472,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

     def _generate_with_cache(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         new_arg_supported = inspect.signature(self._generate).parameters.get(
             "run_manager"
         )
@@ -496,19 +496,33 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 return self._generate(messages, stop=stop, **kwargs)
         else:
             llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = llm_cache.lookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
+            results = []
+            to_generate = []
+            for i, msgs_prompt in enumerate(messages):
+                prompt = dumps(msgs_prompt)
+                cache_val = llm_cache.lookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    results.append((i, ChatResult(generations=cache_val)))
+                else:
+                    to_generate.append((i, msgs_prompt))
+            if to_generate:
+                generate_idxs, generate_messages = tuple(zip(*to_generate))
                 if new_arg_supported:
-                    result = self._generate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
+                    generations = self._generate(
+                        list(generate_messages),
+                        stop=stop,
+                        run_manager=run_manager,
+                        **kwargs,
                     )
                 else:
-                    result = self._generate(messages, stop=stop, **kwargs)
-                llm_cache.update(prompt, llm_string, result.generations)
-                return result
+                    generations = self._generate(
+                        list(generate_messages), stop=stop, **kwargs
+                    )
+                for msgs_prompt, chat_res in zip(generate_messages, generations):
+                    prompt = dumps(msgs_prompt)
+                    llm_cache.update(prompt, llm_string, chat_res.generations)
+                results.extend(zip(generate_idxs, generations))
+            return [res for _, res in sorted(results)]

     async def _agenerate_with_cache(
         self,
```
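The `_generate_with_cache` hunk partitions the batch into cache hits and misses, generates only the misses, then merges the two groups back into the original prompt order by index. A standalone sketch of that merge, with plain strings standing in for `ChatResult` objects (the diff itself relies on `sorted()` over `(index, result)` tuples, which works because the indices are unique):

```python
from typing import List, Tuple


def merge_cached_and_generated(
    cached: List[Tuple[int, str]],     # (original index, result served from the cache)
    generated: List[Tuple[int, str]],  # (original index, freshly generated result)
) -> List[str]:
    """Recombine cache hits and fresh generations in the original prompt order."""
    merged = cached + generated
    merged.sort(key=lambda pair: pair[0])  # indices are unique, so order is well defined
    return [result for _, result in merged]


# Prompts 0 and 2 were cache hits; prompt 1 had to be generated.
assert merge_cached_and_generated(
    [(0, "hit-0"), (2, "hit-2")], [(1, "gen-1")]
) == ["hit-0", "gen-1", "hit-2"]
```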
```diff
@@ -553,20 +567,20 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     @abstractmethod
     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Top Level call"""

     async def _agenerate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Top Level call"""
         return await asyncio.get_running_loop().run_in_executor(
             None, partial(self._generate, **kwargs), messages, stop, run_manager
```
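The default `_agenerate` shown above still just offloads the synchronous batched `_generate` to a thread executor via `functools.partial`. A self-contained sketch of that pattern, with a toy function standing in for `_generate`:

```python
import asyncio
from functools import partial
from typing import Any, List


def generate_sync(batch: List[str], suffix: str = "!") -> List[str]:
    """Stand-in for the synchronous batched _generate."""
    return [prompt + suffix for prompt in batch]


async def agenerate(batch: List[str], **kwargs: Any) -> List[str]:
    """Default async path: run the sync batch call in the default thread executor."""
    return await asyncio.get_running_loop().run_in_executor(
        None, partial(generate_sync, **kwargs), batch
    )


# partial() binds the keyword arguments; positional args (the batch) pass through.
assert asyncio.run(agenerate(["a", "b"], suffix="?")) == ["a?", "b?"]
```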
```diff
@@ -696,15 +710,20 @@ class SimpleChatModel(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
-        message = AIMessage(content=output_str)
-        generation = ChatGeneration(message=message)
-        return ChatResult(generations=[generation])
+    ) -> List[ChatResult]:
+        results = []
+        for msgs_prompt in messages:
+            output_str = self._call(
+                msgs_prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+            message = AIMessage(content=output_str)
+            generation = ChatGeneration(message=message)
+            results.append(ChatResult(generations=[generation]))
+        return results

     @abstractmethod
     def _call(
```
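For `SimpleChatModel`, the batched `_generate` is just the single-prompt `_call` applied to each prompt in order. A toy sketch of that lifting, with plain strings standing in for message objects and `ChatResult`:

```python
from typing import Callable, List


def lift_call_over_batch(
    call_one: Callable[[List[str]], str],  # single-prompt call, like _call above
    batch: List[List[str]],                # a batch: one list of "messages" per prompt
) -> List[str]:
    """Apply a single-prompt call to every prompt in the batch, keeping order."""
    return [call_one(msgs_prompt) for msgs_prompt in batch]


# Toy provider that echoes the last message of each prompt.
echo = lambda msgs: "echo: " + msgs[-1]
assert lift_call_over_batch(echo, [["hi"], ["a", "b"]]) == ["echo: hi", "echo: b"]
```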
```diff
@@ -66,21 +66,22 @@ class BedrockChat(BaseChatModel, BedrockBase):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         completion = ""

-        if self.streaming:
-            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+        if self.streaming and len(messages) == 1:
+            for chunk in self._stream(messages[0], stop, run_manager, **kwargs):
                 completion += chunk.text
-        else:
-            provider = self._get_provider()
-            prompt = ChatPromptAdapter.convert_messages_to_prompt(
-                provider=provider, messages=messages
-            )
+            message = AIMessage(content=completion)
+            return [ChatResult(generations=[ChatGeneration(message=message)])]
+        provider = self._get_provider()
+        results = []
+        for msgs_prompt in messages:
+            prompt = ChatPromptAdapter.convert_messages_to_prompt(provider, msgs_prompt)

             params: Dict[str, Any] = {**kwargs}
             if stop:
@@ -89,9 +90,9 @@ class BedrockChat(BaseChatModel, BedrockBase):
             completion = self._prepare_input_and_invoke(
                 prompt=prompt, stop=stop, run_manager=run_manager, **params
             )

-        message = AIMessage(content=completion)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+            message = AIMessage(content=completion)
+            results.append(ChatResult(generations=[ChatGeneration(message=message)]))
+        return results

     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
```
```diff
@@ -168,29 +168,36 @@ class ChatCohere(BaseChatModel, BaseCohere):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        if self.streaming:
+    ) -> List[ChatResult]:
+        if self.streaming and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

-        request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
-        response = self.client.chat(**request)
+        results = []
+        for msgs_prompt in messages:
+            request = get_cohere_chat_request(
+                msgs_prompt, **self._default_params, **kwargs
+            )
+            response = self.client.chat(**request)

-        message = AIMessage(content=response.text)
-        generation_info = None
-        if hasattr(response, "documents"):
-            generation_info = {"documents": response.documents}
-        return ChatResult(
-            generations=[
-                ChatGeneration(message=message, generation_info=generation_info)
-            ]
-        )
+            message = AIMessage(content=response.text)
+            generation_info = None
+            if hasattr(response, "documents"):
+                generation_info = {"documents": response.documents}
+            results.append(
+                ChatResult(
+                    generations=[
+                        ChatGeneration(message=message, generation_info=generation_info)
+                    ]
+                )
+            )
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -150,33 +150,37 @@ class ErnieBotChat(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         if self.streaming:
             raise ValueError("`streaming` option currently unsupported.")

         if not self.access_token:
             self._refresh_access_token_with_lock()
-        payload = {
-            "messages": [_convert_message_to_dict(m) for m in messages],
-            "top_p": self.top_p,
-            "temperature": self.temperature,
-            "penalty_score": self.penalty_score,
-            **kwargs,
-        }
-        logger.debug(f"Payload for ernie api is {payload}")
-        resp = self._chat(payload)
-        if resp.get("error_code"):
-            if resp.get("error_code") == 111:
-                logger.debug("access_token expired, refresh it")
-                self._refresh_access_token_with_lock()
-                resp = self._chat(payload)
-            else:
-                raise ValueError(f"Error from ErnieChat api response: {resp}")
-        return self._create_chat_result(resp)
+        results = []
+        for msgs_prompt in messages:
+            payload = {
+                "messages": [_convert_message_to_dict(m) for m in msgs_prompt],
+                "top_p": self.top_p,
+                "temperature": self.temperature,
+                "penalty_score": self.penalty_score,
+                **kwargs,
+            }
+            logger.debug(f"Payload for ernie api is {payload}")
+            resp = self._chat(payload)
+            if resp.get("error_code"):
+                if resp.get("error_code") == 111:
+                    logger.debug("access_token expired, refresh it")
+                    self._refresh_access_token_with_lock()
+                    resp = self._chat(payload)
+                else:
+                    raise ValueError(f"Error from ErnieChat api response: {resp}")
+            results.append(self._create_chat_result(resp))
+        return results

     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = [
```
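Inside its new per-prompt loop, the Ernie hunk keeps the existing retry-once-on-expired-token logic. A standalone restatement of that pattern, with callables standing in for the HTTP client and the token refresh (error code 111 is taken from the hunk above):

```python
from typing import Callable, Dict


def call_with_token_refresh(
    do_request: Callable[[Dict], Dict],
    refresh_token: Callable[[], None],
    payload: Dict,
) -> Dict:
    """Retry a request once after refreshing credentials on error_code 111."""
    resp = do_request(payload)
    if resp.get("error_code") == 111:  # 111: access token expired
        refresh_token()
        resp = do_request(payload)
    elif resp.get("error_code"):
        raise ValueError(f"Error from ErnieChat api response: {resp}")
    return resp


# First call reports an expired token; the retry succeeds.
calls = {"n": 0}
def fake_request(payload: Dict) -> Dict:
    calls["n"] += 1
    return {"error_code": 111} if calls["n"] == 1 else {"result": "ok"}

assert call_with_token_refresh(fake_request, lambda: None, {})["result"] == "ok"
```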
```diff
@@ -22,18 +22,21 @@ class FakeMessagesListChatModel(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        response = self.responses[self.i]
-        if self.i < len(self.responses) - 1:
-            self.i += 1
-        else:
-            self.i = 0
-        generation = ChatGeneration(message=response)
-        return ChatResult(generations=[generation])
+    ) -> List[ChatResult]:
+        results = []
+        for _ in messages:
+            response = self.responses[self.i]
+            if self.i < len(self.responses) - 1:
+                self.i += 1
+            else:
+                self.i = 0
+            generation = ChatGeneration(message=response)
+            results.append(ChatResult(generations=[generation]))
+        return results

     @property
     def _llm_type(self) -> str:
```
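The fake chat model keeps one shared response index across the whole batch, wrapping back to the first canned response when it runs out. A toy sketch of that cycling behaviour, with strings in place of message objects:

```python
from typing import List


def cycle_responses(responses: List[str], n_prompts: int) -> List[str]:
    """Hand out canned responses for a batch, wrapping back to the first one."""
    out: List[str] = []
    i = 0
    for _ in range(n_prompts):
        out.append(responses[i])
        i = i + 1 if i < len(responses) - 1 else 0
    return out


assert cycle_responses(["a", "b"], 3) == ["a", "b", "a"]
```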
```diff
@@ -113,22 +113,25 @@ class ChatFireworks(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        message_dicts = self._create_message_dicts(messages)
+    ) -> List[ChatResult]:
+        results = []
+        for msgs_prompt in messages:
+            message_dicts = self._create_message_dicts(msgs_prompt)

-        params = {
-            "model": self.model,
-            "messages": message_dicts,
-            **self.model_kwargs,
-        }
-        response = completion_with_retry(
-            self, run_manager=run_manager, stop=stop, **params
-        )
-        return self._create_chat_result(response)
+            params = {
+                "model": self.model,
+                "messages": message_dicts,
+                **self.model_kwargs,
+            }
+            response = completion_with_retry(
+                self, run_manager=run_manager, stop=stop, **params
+            )
+            results.append(self._create_chat_result(response))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -104,23 +104,26 @@ class GigaChat(_BaseGigaChat, BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         stream: Optional[bool] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if should_stream and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

-        payload = self._build_payload(messages)
-        response = self._client.chat(payload)
+        results = []
+        for msgs_prompt in messages:
+            payload = self._build_payload(msgs_prompt)
+            response = self._client.chat(payload)

-        return self._create_chat_result(response)
+            results.append(self._create_chat_result(response))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -287,25 +287,28 @@ class ChatGooglePalm(BaseChatModel, BaseModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        prompt = _messages_to_prompt_dict(messages)
+    ) -> List[ChatResult]:
+        results = []
+        for msgs_prompt in messages:
+            prompt = _messages_to_prompt_dict(msgs_prompt)

-        response: genai.types.ChatResponse = chat_with_retry(
-            self,
-            model=self.model_name,
-            prompt=prompt,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            candidate_count=self.n,
-            **kwargs,
-        )
+            response: genai.types.ChatResponse = chat_with_retry(
+                self,
+                model=self.model_name,
+                prompt=prompt,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                top_k=self.top_k,
+                candidate_count=self.n,
+                **kwargs,
+            )

-        return _response_to_result(response, stop)
+            results.append(_response_to_result(response, stop))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -89,11 +89,11 @@ class HumanInputChatModel(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """
         Displays the messages to the user and returns their input as a response.
@@ -105,9 +105,12 @@ class HumanInputChatModel(BaseChatModel):
         Returns:
             ChatResult: The user's input as a response.
         """
-        self.message_func(messages, **self.message_kwargs)
-        user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
-        return ChatResult(generations=[ChatGeneration(message=user_input)])
+        results = []
+        for msgs_prompt in messages:
+            self.message_func(msgs_prompt, **self.message_kwargs)
+            user_input = self.input_func(msgs_prompt, stop=stop, **self.input_kwargs)
+            results.append(ChatResult(generations=[ChatGeneration(message=user_input)]))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -240,25 +240,28 @@ class ChatHunyuan(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        if self.streaming:
+    ) -> List[ChatResult]:
+        if self.streaming and len(messages) == 1:
             stream_iter = self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

-        res = self._chat(messages, **kwargs)
+        results = []
+        for msgs_prompt in messages:
+            res = self._chat(msgs_prompt, **kwargs)

-        response = res.json()
+            response = res.json()

-        if "error" in response:
-            raise ValueError(f"Error from Hunyuan api response: {response}")
+            if "error" in response:
+                raise ValueError(f"Error from Hunyuan api response: {response}")

-        return _create_chat_result(response)
+            results.append(_create_chat_result(response))
+        return results

     def _stream(
         self,
```
```diff
@@ -101,23 +101,26 @@ class ChatJavelinAIGateway(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        message_dicts = [
-            ChatJavelinAIGateway._convert_message_to_dict(message)
-            for message in messages
-        ]
-        data: Dict[str, Any] = {
-            "messages": message_dicts,
-            **(self.params.dict() if self.params else {}),
-        }
+    ) -> List[ChatResult]:
+        results = []
+        for msgs_prompt in messages:
+            message_dicts = [
+                ChatJavelinAIGateway._convert_message_to_dict(message)
+                for message in msgs_prompt
+            ]
+            data: Dict[str, Any] = {
+                "messages": message_dicts,
+                **(self.params.dict() if self.params else {}),
+            }

-        resp = self.client.query_route(self.route, query_body=data)
+            resp = self.client.query_route(self.route, query_body=data)

-        return ChatJavelinAIGateway._create_chat_result(resp.dict())
+            results.append(ChatJavelinAIGateway._create_chat_result(resp.dict()))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -319,21 +319,24 @@ class JinaChat(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        if self.streaming:
+    ) -> List[ChatResult]:
+        if self.streaming and len(messages) == 1:
             stream_iter = self._stream(
-                messages=messages, stop=stop, run_manager=run_manager, **kwargs
+                messages=messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs}
-        response = self.completion_with_retry(messages=message_dicts, **params)
-        return self._create_chat_result(response)
+        results = []
+        for msgs_prompt in messages:
+            message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
+            params = {**params, **kwargs}
+            response = self.completion_with_retry(messages=message_dicts, **params)
+            results.append(self._create_chat_result(response))
+        return results

     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
```
```diff
@@ -289,25 +289,27 @@ class ChatLiteLLM(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         stream: Optional[bool] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if should_stream and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
-
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs}
-        response = self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        )
-        return self._create_chat_result(response)
+            return [_generate_from_stream(stream_iter)]
+        results = []
+        for msgs_prompt in messages:
+            message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
+            params = {**params, **kwargs}
+            response = self.completion_with_retry(
+                messages=message_dicts, run_manager=run_manager, **params
+            )
+            results.append(self._create_chat_result(response))
+        return results

     def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
         generations = []
```
```diff
@@ -51,11 +51,11 @@ class MiniMaxChat(MinimaxCommon, BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Generate next turn in the conversation.
         Args:
             messages: The history of the conversation as a list of messages. Code chat
@@ -73,13 +73,16 @@ class MiniMaxChat(MinimaxCommon, BaseChatModel):
             raise ValueError(
                 "You should provide at least one message to start the chat!"
             )
-        history = _parse_chat_history(messages)
-        payload = self._default_params
-        payload["messages"] = history
-        text = self._client.post(payload)
+        results = []
+        for msgs_prompt in messages:
+            history = _parse_chat_history(msgs_prompt)
+            payload = self._default_params
+            payload["messages"] = history
+            text = self._client.post(payload)

-        # This is required since the stop are not enforced by the model parameters
-        return text if stop is None else enforce_stop_tokens(text, stop)
+            # This is required since the stop are not enforced by the model parameters
+            results.append(text if stop is None else enforce_stop_tokens(text, stop))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -85,11 +85,11 @@ class ChatMLflowAIGateway(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         try:
             import mlflow.gateway
         except ImportError as e:
@@ -98,17 +98,20 @@ class ChatMLflowAIGateway(BaseChatModel):
                 "Please install it with `pip install mlflow[gateway]`."
             ) from e

-        message_dicts = [
-            ChatMLflowAIGateway._convert_message_to_dict(message)
-            for message in messages
-        ]
-        data: Dict[str, Any] = {
-            "messages": message_dicts,
-            **(self.params.dict() if self.params else {}),
-        }
+        results = []
+        for msgs_prompt in messages:
+            message_dicts = [
+                ChatMLflowAIGateway._convert_message_to_dict(message)
+                for message in msgs_prompt
+            ]
+            data: Dict[str, Any] = {
+                "messages": message_dicts,
+                **(self.params.dict() if self.params else {}),
+            }

-        resp = mlflow.gateway.query(self.route, data=data)
-        return ChatMLflowAIGateway._create_chat_result(resp)
+            resp = mlflow.gateway.query(self.route, data=data)
+            results.append(ChatMLflowAIGateway._create_chat_result(resp))
+        return results

     async def _agenerate(
         self,
```
```diff
@@ -72,11 +72,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Call out to Ollama's generate endpoint.

         Args:
```
```diff
@@ -343,24 +343,27 @@ class ChatOpenAI(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         stream: Optional[bool] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if should_stream and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs}
-        response = self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        )
-        return self._create_chat_result(response)
+            return [_generate_from_stream(stream_iter)]
+        results = []
+        for msgs_prompt in messages:
+            message_dicts, params = self._create_message_dicts(msgs_prompt, stop)
+            params = {**params, **kwargs}
+            response = self.completion_with_retry(
+                messages=message_dicts, run_manager=run_manager, **params
+            )
+            results.append(self._create_chat_result(response))
+        return results

     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
```
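Every streaming-capable provider above uses the same guard: stream only when the batch contains exactly one prompt, otherwise fall back to per-prompt blocking calls. A self-contained sketch of that decision, with strings standing in for message lists and chat results:

```python
from typing import Callable, Iterable, List


def generate_or_stream(
    batch: List[str],
    streaming: bool,
    stream_one: Callable[[str], Iterable[str]],  # yields chunks for a single prompt
    complete_one: Callable[[str], str],          # blocking call for a single prompt
) -> List[str]:
    """Stream only when the batch holds exactly one prompt; otherwise loop."""
    if streaming and len(batch) == 1:
        return ["".join(stream_one(batch[0]))]   # one result assembled from the stream
    return [complete_one(prompt) for prompt in batch]


chunks = lambda prompt: iter([prompt, "!"])
full = lambda prompt: prompt + "."
assert generate_or_stream(["hi"], True, chunks, full) == ["hi!"]
# With two prompts, streaming is skipped and both prompts are completed in the loop.
assert generate_or_stream(["hi", "yo"], True, chunks, full) == ["hi.", "yo."]
```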
```diff
@@ -182,15 +182,20 @@ class PaiEasChatEndpoint(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
-        message = AIMessage(content=output_str)
-        generation = ChatGeneration(message=message)
-        return ChatResult(generations=[generation])
+    ) -> List[ChatResult]:
+        results = []
+        for msgs_prompt in messages:
+            output_str = self._call(
+                msgs_prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+            message = AIMessage(content=output_str)
+            generation = ChatGeneration(message=message)
+            results.append(ChatResult(generations=[generation]))
+        return results

     def _call(
         self,
```
```diff
@@ -306,32 +306,35 @@ class ChatTongyi(BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         stream: Optional[bool] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if should_stream and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

         if not messages:
             raise ValueError("No messages provided.")

-        message_dicts, params = self._create_message_dicts(messages, stop)
+        results = []
+        for msgs_prompt in messages:
+            message_dicts, params = self._create_message_dicts(msgs_prompt, stop)

-        if message_dicts[-1]["role"] != "user":
-            raise ValueError("Last message should be user message.")
+            if message_dicts[-1]["role"] != "user":
+                raise ValueError("Last message should be user message.")

-        params = {**params, **kwargs}
-        response = self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        )
-        return self._create_chat_result(response)
+            params = {**params, **kwargs}
+            response = self.completion_with_retry(
+                messages=message_dicts, run_manager=run_manager, **params
+            )
+            results.append(self._create_chat_result(response))
+        return results

     def _stream(
         self,
```
```diff
@@ -145,12 +145,12 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         stream: Optional[bool] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Generate next turn in the conversation.

         Args:
@@ -167,11 +167,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
             ValueError: if the last message in the list is not from human.
         """
         should_stream = stream if stream is not None else self.streaming
-        if should_stream:
+        if should_stream and len(messages) == 1:
             stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
+                messages[0], stop=stop, run_manager=run_manager, **kwargs
             )
-            return _generate_from_stream(stream_iter)
+            return [_generate_from_stream(stream_iter)]

         question = _get_question(messages)
         history = _parse_chat_history(messages[:-1])
```
```diff
@@ -63,11 +63,11 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):

     def _generate(
         self,
-        messages: List[BaseMessage],
+        messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> ChatResult:
+    ) -> List[ChatResult]:
         """Generate next turn in the conversation.
         Args:
             messages: The history of the conversation as a list of messages.
@@ -96,28 +96,31 @@ class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
             raise ValueError(
                 "You should provide at least one message to start the chat!"
             )
-        message_history, instruction = _parse_chat_history(messages)
-        channel_credentials = grpc.ssl_channel_credentials()
-        channel = grpc.secure_channel(self.url, channel_credentials)
-        request = ChatRequest(
-            model=self.model_name,
-            generation_options=GenerationOptions(
-                temperature=DoubleValue(value=self.temperature),
-                max_tokens=Int64Value(value=self.max_tokens),
-            ),
-            instruction_text=instruction,
-            messages=[Message(**message) for message in message_history],
-        )
-        stub = TextGenerationServiceStub(channel)
-        if self.iam_token:
-            metadata = (("authorization", f"Bearer {self.iam_token}"),)
-        else:
-            metadata = (("authorization", f"Api-Key {self.api_key}"),)
-        res = stub.Chat(request, metadata=metadata)
-        text = list(res)[0].message.text
-        text = text if stop is None else enforce_stop_tokens(text, stop)
-        message = AIMessage(content=text)
-        return ChatResult(generations=[ChatGeneration(message=message)])
+        results = []
+        for msgs_prompt in messages:
+            message_history, instruction = _parse_chat_history(msgs_prompt)
+            channel_credentials = grpc.ssl_channel_credentials()
+            channel = grpc.secure_channel(self.url, channel_credentials)
+            request = ChatRequest(
+                model=self.model_name,
+                generation_options=GenerationOptions(
+                    temperature=DoubleValue(value=self.temperature),
+                    max_tokens=Int64Value(value=self.max_tokens),
+                ),
+                instruction_text=instruction,
+                messages=[Message(**message) for message in message_history],
+            )
+            stub = TextGenerationServiceStub(channel)
+            if self.iam_token:
+                metadata = (("authorization", f"Bearer {self.iam_token}"),)
+            else:
+                metadata = (("authorization", f"Api-Key {self.api_key}"),)
+            res = stub.Chat(request, metadata=metadata)
+            text = list(res)[0].message.text
+            text = text if stop is None else enforce_stop_tokens(text, stop)
+            message = AIMessage(content=text)
+            results.append(ChatResult(generations=[ChatGeneration(message=message)]))
+        return results

     async def _agenerate(
         self,
```
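For callers, `BaseChatModel.generate` already accepted a batch of message lists; this refactor changes what the provider-level `_generate` sees, not the public calling shape. A usage sketch against that public API (assumes this era's import paths and a configured `OPENAI_API_KEY`):

```python
# Requires an OPENAI_API_KEY in the environment; import paths are from this era.
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI()
batch = [
    [HumanMessage(content="Tell me a joke.")],
    [HumanMessage(content="Translate 'hello' to French.")],
]
result = chat.generate(batch)  # LLMResult: one generation list per input prompt
print([generations[0].text for generations in result.generations])
```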