community[patch]: Fix ChatDatabricks when a streaming response's delta chunk has no role field (#21897)

Thank you for contributing to LangChain!

- [X] **PR title**: "package: description"
- Where "package" is whichever of langchain, community, core,
experimental, etc. is being modified. Use "docs: ..." for purely docs
changes, "templates: ..." for template changes, "infra: ..." for CI
changes.
  - Example: "community: add foobar LLM"


**Description:**
Fix ChatDatabricks when a streaming response's delta chunk has no role
field. The role from the first delta chunk is remembered and used as a
fallback for later chunks that omit it.
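The pattern behind the fix can be sketched as follows. This is a minimal, self-contained illustration with hypothetical names (`stream_roles` is not part of the PR); OpenAI-style streaming APIs typically send `"role"` only in the first delta chunk, so later chunks fall back to the role remembered from the first one:

```python
from typing import Any, List, Mapping, Optional


def stream_roles(choices: List[Mapping[str, Any]]) -> List[str]:
    # Remember the role seen in the first delta chunk.
    first_chunk_role: Optional[str] = None
    roles = []
    for choice in choices:
        delta = choice["delta"]
        if first_chunk_role is None:
            first_chunk_role = delta.get("role")
        # Later chunks may omit "role"; fall back to the remembered one.
        roles.append(delta.get("role", first_chunk_role))
    return roles
```
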


- [ ] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.


- [X] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.

---------

Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
This commit is contained in:
WeichenXu 2024-05-22 23:12:53 +08:00 committed by GitHub
parent aed64daabb
commit b0ef5e778a
GPG Key ID: B5690EEEBB952194


```diff
@@ -174,9 +174,16 @@ class ChatMlflow(BaseChatModel):
             )
         # TODO: check if `_client.predict_stream` is available.
         chunk_iter = self._client.predict_stream(endpoint=self.endpoint, inputs=data)
+        first_chunk_role = None
         for chunk in chunk_iter:
             choice = chunk["choices"][0]
-            chunk = ChatMlflow._convert_delta_to_message_chunk(choice["delta"])
+            chunk_delta = choice["delta"]
+            if first_chunk_role is None:
+                first_chunk_role = chunk_delta.get("role")
+            chunk = ChatMlflow._convert_delta_to_message_chunk(
+                chunk_delta, first_chunk_role
+            )

             generation_info = {}
             if finish_reason := choice.get("finish_reason"):
@@ -225,8 +232,10 @@ class ChatMlflow(BaseChatModel):
         return ChatMessage(content=content, role=role)

     @staticmethod
-    def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
-        role = _dict["role"]
+    def _convert_delta_to_message_chunk(
+        _dict: Mapping[str, Any], default_role: str
+    ) -> BaseMessageChunk:
+        role = _dict.get("role", default_role)
         content = _dict["content"]
         if role == "user":
            return HumanMessageChunk(content=content)
```
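The difference between the old and new lookup can be shown in isolation. This is a minimal sketch (the delta dict and fallback role value are illustrative, not taken from a real response):

```python
# A later streaming chunk that omits "role", as Databricks endpoints may send.
delta = {"content": "llo"}

# Old behavior: a plain subscript raises KeyError when "role" is absent.
try:
    role = delta["role"]
except KeyError:
    role = None
assert role is None

# New behavior: dict.get falls back to the role carried over from the
# first chunk (here an assumed value for illustration).
role = delta.get("role", "assistant")
assert role == "assistant"
```
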