mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
community[patch]: Fix ChatDatabricsk in case that streaming response doesn't have role field in delta chunk (#21897)
Thank you for contributing to LangChain! - [X] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" **Description:** Fix ChatDatabricsk in case that streaming response doesn't have role field in delta chunk - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [X] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
parent
aed64daabb
commit
b0ef5e778a
@ -174,9 +174,16 @@ class ChatMlflow(BaseChatModel):
|
||||
)
|
||||
# TODO: check if `_client.predict_stream` is available.
|
||||
chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
|
||||
first_chunk_role = None
|
||||
for chunk in chunk_iter:
|
||||
choice = chunk["choices"][0]
|
||||
chunk = ChatMlflow._convert_delta_to_message_chunk(choice["delta"])
|
||||
|
||||
chunk_delta = choice["delta"]
|
||||
if first_chunk_role is None:
|
||||
first_chunk_role = chunk_delta.get("role")
|
||||
chunk = ChatMlflow._convert_delta_to_message_chunk(
|
||||
chunk_delta, first_chunk_role
|
||||
)
|
||||
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
@ -225,8 +232,10 @@ class ChatMlflow(BaseChatModel):
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
@staticmethod
|
||||
def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
|
||||
role = _dict["role"]
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_role: str
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role", default_role)
|
||||
content = _dict["content"]
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
|
Loading…
Reference in New Issue
Block a user