community[patch]: add missing chunk parameter for _stream/_astream (#17807)
- Description: Add the missing `chunk` parameter to `_stream`/`_astream` for several chat models, making all chat models behave consistently.
- Issue: N/A
- Dependencies: N/A
parent 1b0802babe
commit 31891092d8
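For context, the change standardizes one streaming callback pattern across these chat models: build the ChatGenerationChunk once, yield it, and pass the same object through the chunk parameter of on_llm_new_token. A minimal sketch of that pattern follows (the standalone _stream_deltas helper and its deltas argument are illustrative only, not part of the diff; import paths assume langchain_core as used by these models):

    from typing import Iterable, Iterator, Optional

    from langchain_core.callbacks import CallbackManagerForLLMRun
    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk


    def _stream_deltas(
        deltas: Iterable[str],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> Iterator[ChatGenerationChunk]:
        # Wrap each text delta in a ChatGenerationChunk, yield it, and hand the
        # same object to the callback manager so on_llm_new_token receives the
        # generation chunk alongside the raw token text.
        for delta in deltas:
            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
            yield cg_chunk
            if run_manager:
                run_manager.on_llm_new_token(delta, chunk=cg_chunk)

With this in place, callback handlers that accept the chunk keyword argument in on_llm_new_token receive it from every one of these community chat models, not only from some of them.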
@@ -218,9 +218,10 @@ class ChatBaichuan(BaseChatModel):
                 m.get("delta"), default_chunk_class
             )
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         parameters = {**self._default_params, **kwargs}
@@ -147,9 +147,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
         for data in stream:
             if data.event_type == "text-generation":
                 delta = data.text
-                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(delta)
+                    run_manager.on_llm_new_token(delta, chunk=chunk)

     async def _astream(
         self,
@@ -164,9 +165,10 @@ class ChatCohere(BaseChatModel, BaseCohere):
         async for data in stream:
             if data.event_type == "text-generation":
                 delta = data.text
-                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                yield chunk
                 if run_manager:
-                    await run_manager.on_llm_new_token(delta)
+                    await run_manager.on_llm_new_token(delta, chunk=chunk)

     def _get_generation_info(self, response: Any) -> Dict[str, Any]:
         """Get the generation info from cohere API response."""
@@ -328,9 +328,10 @@ class ChatDeepInfra(BaseChatModel):
         for line in _parse_stream(response.iter_lines()):
             chunk = _handle_sse_line(line)
             if chunk:
-                yield ChatGenerationChunk(message=chunk, generation_info=None)
+                cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
+                yield cg_chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(str(chunk.content))
+                    run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)

     async def _astream(
         self,
@@ -350,9 +351,12 @@ class ChatDeepInfra(BaseChatModel):
         async for line in _parse_stream_async(response.content):
             chunk = _handle_sse_line(line)
             if chunk:
-                yield ChatGenerationChunk(message=chunk, generation_info=None)
+                cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
+                yield cg_chunk
                 if run_manager:
-                    await run_manager.on_llm_new_token(str(chunk.content))
+                    await run_manager.on_llm_new_token(
+                        str(chunk.content), chunk=cg_chunk
+                    )

     async def _agenerate(
         self,
@@ -154,9 +154,10 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
         for chunk in self._client.stream(payload):
             if chunk.choices:
                 content = chunk.choices[0].delta.content
-                yield ChatGenerationChunk(message=AIMessageChunk(content=content))
+                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
+                yield cg_chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(content)
+                    run_manager.on_llm_new_token(content, chunk=cg_chunk)

     async def _astream(
         self,
@@ -170,9 +171,10 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
         async for chunk in self._client.astream(payload):
             if chunk.choices:
                 content = chunk.choices[0].delta.content
-                yield ChatGenerationChunk(message=AIMessageChunk(content=content))
+                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
+                yield cg_chunk
                 if run_manager:
-                    await run_manager.on_llm_new_token(content)
+                    await run_manager.on_llm_new_token(content, chunk=cg_chunk)

     def get_num_tokens(self, text: str) -> int:
         """Count approximate number of tokens"""
@@ -275,9 +275,10 @@ class ChatHunyuan(BaseChatModel):
                 choice["delta"], default_chunk_class
             )
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.hunyuan_secret_key is None:
@@ -312,9 +312,10 @@ class JinaChat(BaseChatModel):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     def _generate(
         self,
@@ -371,9 +372,10 @@ class JinaChat(BaseChatModel):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     async def _agenerate(
         self,
@@ -355,9 +355,10 @@ class ChatLiteLLM(BaseChatModel):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     async def _astream(
         self,
@@ -378,9 +379,10 @@ class ChatLiteLLM(BaseChatModel):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     async def _agenerate(
         self,
@@ -123,9 +123,10 @@ class ChatLiteLLMRouter(ChatLiteLLM):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content, **params)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)

     async def _astream(
         self,
@@ -148,9 +149,12 @@ class ChatLiteLLMRouter(ChatLiteLLM):
             delta = chunk["choices"][0]["delta"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, **params)
+                await run_manager.on_llm_new_token(
+                    chunk.content, chunk=cg_chunk, **params
+                )

     async def _agenerate(
         self,
@@ -195,6 +195,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 if run_manager:
                     run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=verbose,
                     )
         if final_chunk is None:
@@ -221,6 +222,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 if run_manager:
                     await run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=verbose,
                     )
         if final_chunk is None:
@@ -291,9 +291,12 @@ class PaiEasChatEndpoint(BaseChatModel):

             # yield text, if any
             if text:
+                cg_chunk = ChatGenerationChunk(message=content)
                 if run_manager:
-                    await run_manager.on_llm_new_token(cast(str, content.content))
-                yield ChatGenerationChunk(message=content)
+                    await run_manager.on_llm_new_token(
+                        cast(str, content.content), chunk=cg_chunk
+                    )
+                yield cg_chunk

             # break if stop sequence found
             if stop_seq_found:
@@ -224,9 +224,10 @@ class ChatSparkLLM(BaseChatModel):
                 continue
             delta = content["data"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            yield ChatGenerationChunk(message=chunk)
+            cg_chunk = ChatGenerationChunk(message=chunk)
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(str(chunk.content))
+                run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)

     def _generate(
         self,
@@ -376,9 +376,10 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         chat = self._start_chat(history, **params)
         responses = chat.send_message_streaming(question.content, **params)
         for response in responses:
+            chunk = ChatGenerationChunk(message=AIMessageChunk(content=response.text))
             if run_manager:
-                run_manager.on_llm_new_token(response.text)
-            yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
+                run_manager.on_llm_new_token(response.text, chunk=chunk)
+            yield chunk

     def _start_chat(
         self, history: _ChatHistory, **kwargs: Any
@@ -116,9 +116,10 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
         for res in self.client.stream_chat(params):
             if res:
                 msg = convert_dict_to_message(res)
-                yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
+                yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(cast(str, msg.content))
+                    run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)

     def _generate(
         self,
@@ -269,12 +269,13 @@ class ChatYuan2(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(
+            cg_chunk = ChatGenerationChunk(
                 message=chunk,
                 generation_info=generation_info,
             )
+            yield cg_chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.content)
+                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     def _generate(
         self,
@@ -351,12 +352,13 @@ class ChatYuan2(BaseChatModel):
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
            )
             default_chunk_class = chunk.__class__
-            yield ChatGenerationChunk(
+            cg_chunk = ChatGenerationChunk(
                 message=chunk,
                 generation_info=generation_info,
             )
+            yield cg_chunk
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content)
+                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

     async def _agenerate(
         self,
@@ -327,9 +327,10 @@ class ChatZhipuAI(BaseChatModel):
         for r in response.events():
             if r.event == "add":
                 delta = r.data
-                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
+                yield chunk
                 if run_manager:
-                    run_manager.on_llm_new_token(delta)
+                    run_manager.on_llm_new_token(delta, chunk=chunk)

             elif r.event == "error":
                 raise ValueError(f"Error from ZhipuAI API response: {r.data}")