community[patch]: add missing chunk parameter for _stream/_astream (#17807)

- Description: Pass the missing chunk parameter to run_manager.on_llm_new_token in the _stream/_astream implementations of several chat models, so that all chat models behave consistently (a minimal sketch of the pattern is shown below).
- Issue: N/A
- Dependencies: N/A
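
Concretely, each touched _stream/_astream now assigns the ChatGenerationChunk to a local variable, yields it, and hands the same object to run_manager.on_llm_new_token through the chunk keyword argument. A minimal, self-contained sketch of that pattern (the stream_tokens helper and raw_tokens iterator are illustrative names, not code from any provider integration in this PR):

```python
# Minimal sketch of the pattern applied in this change; stream_tokens and
# raw_tokens are illustrative placeholders, not part of this PR.
from typing import Iterator, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk


def stream_tokens(
    raw_tokens: Iterator[str],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[ChatGenerationChunk]:
    """Yield one ChatGenerationChunk per token and forward it to callbacks."""
    for token in raw_tokens:
        # Build the chunk once so the exact same object is both yielded to the
        # caller and passed to the callback via the chunk keyword argument.
        cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
        yield cg_chunk
        if run_manager:
            run_manager.on_llm_new_token(token, chunk=cg_chunk)
```

Forwarding the chunk lets callback handlers see the full ChatGenerationChunk (message plus any generation_info) rather than only the token text, bringing these integrations in line with the chat models that already pass it.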
mackong 2024-02-22 07:32:28 +08:00 committed by GitHub
parent 1b0802babe
commit 31891092d8
15 changed files with 71 additions and 42 deletions


@@ -218,9 +218,10 @@ class ChatBaichuan(BaseChatModel):
m.get("delta"), default_chunk_class
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
parameters = {**self._default_params, **kwargs}


@@ -147,9 +147,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta)
run_manager.on_llm_new_token(delta, chunk=chunk)
async def _astream(
self,
@@ -164,9 +165,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
async for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
await run_manager.on_llm_new_token(delta)
await run_manager.on_llm_new_token(delta, chunk=chunk)
def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""


@@ -328,9 +328,10 @@ class ChatDeepInfra(BaseChatModel):
for line in _parse_stream(response.iter_lines()):
chunk = _handle_sse_line(line)
if chunk:
yield ChatGenerationChunk(message=chunk, generation_info=None)
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content))
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
async def _astream(
self,
@@ -350,9 +351,12 @@ class ChatDeepInfra(BaseChatModel):
async for line in _parse_stream_async(response.content):
chunk = _handle_sse_line(line)
if chunk:
yield ChatGenerationChunk(message=chunk, generation_info=None)
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(str(chunk.content))
await run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk
)
async def _agenerate(
self,


@@ -154,9 +154,10 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
for chunk in self._client.stream(payload):
if chunk.choices:
content = chunk.choices[0].delta.content
yield ChatGenerationChunk(message=AIMessageChunk(content=content))
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(content)
run_manager.on_llm_new_token(content, chunk=cg_chunk)
async def _astream(
self,
@@ -170,9 +171,10 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
async for chunk in self._client.astream(payload):
if chunk.choices:
content = chunk.choices[0].delta.content
yield ChatGenerationChunk(message=AIMessageChunk(content=content))
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(content)
await run_manager.on_llm_new_token(content, chunk=cg_chunk)
def get_num_tokens(self, text: str) -> int:
"""Count approximate number of tokens"""


@@ -275,9 +275,10 @@ class ChatHunyuan(BaseChatModel):
choice["delta"], default_chunk_class
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.hunyuan_secret_key is None:


@@ -312,9 +312,10 @@ class JinaChat(BaseChatModel):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
def _generate(
self,
@@ -371,9 +372,10 @@ class JinaChat(BaseChatModel):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
async def _agenerate(
self,


@@ -355,9 +355,10 @@ class ChatLiteLLM(BaseChatModel):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
async def _astream(
self,
@@ -378,9 +379,10 @@ class ChatLiteLLM(BaseChatModel):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
async def _agenerate(
self,


@@ -123,9 +123,10 @@ class ChatLiteLLMRouter(ChatLiteLLM):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, **params)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)
async def _astream(
self,
@@ -148,9 +149,12 @@ class ChatLiteLLMRouter(ChatLiteLLM):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content, **params)
await run_manager.on_llm_new_token(
chunk.content, chunk=cg_chunk, **params
)
async def _agenerate(
self,


@@ -195,6 +195,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
@@ -221,6 +222,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:


@@ -291,9 +291,12 @@ class PaiEasChatEndpoint(BaseChatModel):
# yield text, if any
if text:
cg_chunk = ChatGenerationChunk(message=content)
if run_manager:
await run_manager.on_llm_new_token(cast(str, content.content))
yield ChatGenerationChunk(message=content)
await run_manager.on_llm_new_token(
cast(str, content.content), chunk=cg_chunk
)
yield cg_chunk
# break if stop sequence found
if stop_seq_found:


@@ -224,9 +224,10 @@ class ChatSparkLLM(BaseChatModel):
continue
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content))
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
def _generate(
self,


@@ -376,9 +376,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=response.text))
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
run_manager.on_llm_new_token(response.text, chunk=chunk)
yield chunk
def _start_chat(
self, history: _ChatHistory, **kwargs: Any


@@ -116,9 +116,10 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
for res in self.client.stream_chat(params):
if res:
msg = convert_dict_to_message(res)
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
yield chunk
if run_manager:
run_manager.on_llm_new_token(cast(str, msg.content))
run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)
def _generate(
self,


@@ -269,12 +269,13 @@ class ChatYuan2(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
def _generate(
self,
@@ -351,12 +352,13 @@ class ChatYuan2(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
async def _agenerate(
self,


@@ -327,9 +327,10 @@ class ChatZhipuAI(BaseChatModel):
for r in response.events():
if r.event == "add":
delta = r.data
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta)
run_manager.on_llm_new_token(delta, chunk=chunk)
elif r.event == "error":
raise ValueError(f"Error from ZhipuAI API response: {r.data}")