feat: Implement stream interface (#15875)

Major changes:

- Rename `wasm_chat.py` to `llama_edge.py`
- Rename the `WasmChatService` class to `LlamaEdgeChatService`
- Implement the `stream` interface for `LlamaEdgeChatService` (see the usage sketch below)
- Add `test_chat_wasm_service_streaming` to the integration tests
- Update `llama_edge.ipynb`
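
For reference, a minimal sketch of the new streaming usage, assuming a `llama-api-server` instance reachable at a placeholder URL (the URL and prompt are illustrative, not taken from this PR):

```python
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.llama_edge import LlamaEdgeChatService

# Placeholder endpoint; point this at a running llama-api-server instance.
chat = LlamaEdgeChatService(service_url="http://localhost:8080", streaming=True)

# `stream` yields message chunks as the service produces them.
for chunk in chat.stream([HumanMessage(content="What is the capital of France?")]):
    print(chunk.content, end="", flush=True)
```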

---------

Signed-off-by: Xin Liu <sam@secondstate.io>
Author: Xin Liu
Date: 2024-01-12 13:32:48 +08:00
Committed by: GitHub
Parent: ec4dab0449
Commit: 5efec068c9

8 changed files with 299 additions and 128 deletions

View File

@@ -39,6 +39,7 @@ from langchain_community.chat_models.javelin_ai_gateway import ChatJavelinAIGate
from langchain_community.chat_models.jinachat import JinaChat
from langchain_community.chat_models.konko import ChatKonko
from langchain_community.chat_models.litellm import ChatLiteLLM
from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
from langchain_community.chat_models.minimax import MiniMaxChat
from langchain_community.chat_models.mlflow import ChatMlflow
from langchain_community.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
@@ -49,12 +50,11 @@ from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOp
from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_community.chat_models.vertexai import ChatVertexAI
from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat
from langchain_community.chat_models.wasm_chat import WasmChatService
from langchain_community.chat_models.yandex import ChatYandexGPT
from langchain_community.chat_models.zhipuai import ChatZhipuAI
__all__ = [
"WasmChatService",
"LlamaEdgeChatService",
"ChatOpenAI",
"BedrockChat",
"AzureChatOpenAI",

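The two hunks above swap the public export: the `wasm_chat` import is dropped in favor of `llama_edge`, and `__all__` is updated to match. Downstream code would migrate roughly like this (a sketch of the caller-side change, not part of the diff):

```python
# Before this PR:
# from langchain_community.chat_models import WasmChatService
# After:
from langchain_community.chat_models import LlamaEdgeChatService
```
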
View File

@@ -1,18 +1,26 @@
import json
import logging
from typing import Any, Dict, List, Mapping, Optional
import re
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils import get_pydantic_field_names
@@ -45,10 +53,26 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return message_dict
class WasmChatService(BaseChatModel):
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
class LlamaEdgeChatService(BaseChatModel):
"""Chat with LLMs via `llama-api-server`
For the information about `llama-api-server`, visit https://github.com/second-state/llama-utils
For the information about `llama-api-server`, visit https://github.com/second-state/LlamaEdge
"""
request_timeout: int = 60
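To make the new `_convert_delta_to_message_chunk` helper concrete: it dispatches on the delta's `role` field, and when the role is absent it falls back to the chunk class supplied by the caller. Hypothetical inputs and their results, following the branches shown above:

```python
from langchain_core.messages import AIMessageChunk

# role present: dispatch on it.
_convert_delta_to_message_chunk({"role": "assistant", "content": "Par"}, AIMessageChunk)
# -> AIMessageChunk(content="Par")

# role absent: fall through to the default class passed by the caller.
_convert_delta_to_message_chunk({"content": "is"}, AIMessageChunk)
# -> AIMessageChunk(content="is")
```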
@@ -57,6 +81,8 @@ class WasmChatService(BaseChatModel):
"""URL of WasmChat service"""
model: str = "NA"
"""model name, default is `NA`."""
streaming: bool = False
"""Whether to stream the results or not."""
class Config:
"""Configuration for this pydantic object."""
@@ -96,6 +122,12 @@ class WasmChatService(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
res = self._chat(messages, **kwargs)
if res.status_code != 200:
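When `streaming=True`, `_generate` defers entirely to `_stream` via `generate_from_stream`, a `langchain_core` helper that drains the chunk iterator and folds it into a single `ChatResult`. Conceptually it does something like this (a sketch of the idea, not the library's exact code):

```python
from typing import Iterator

from langchain_core.outputs import ChatGenerationChunk, ChatResult

def fold_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    # ChatGenerationChunk supports `+`, which concatenates message content.
    generation = next(stream)
    for chunk in stream:
        generation = generation + chunk
    return ChatResult(generations=[generation])
```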
@@ -105,6 +137,64 @@ class WasmChatService(BaseChatModel):
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
res = self._chat(messages, **kwargs)
default_chunk_class = AIMessageChunk
substring = '"object":"chat.completion.chunk"}'
for line in res.iter_lines():
chunks = []
if line:
json_string = line.decode("utf-8")
# Find all positions of the substring
positions = [m.start() for m in re.finditer(substring, json_string)]
positions = [-1 * len(substring)] + positions
for i in range(len(positions) - 1):
chunk = json.loads(
json_string[
positions[i] + len(substring) : positions[i + 1]
+ len(substring)
]
)
chunks.append(chunk)
for chunk in chunks:
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
if (
choice.get("finish_reason") is not None
and choice.get("finish_reason") == "stop"
):
break
finish_reason = choice.get("finish_reason")
generation_info = (
dict(finish_reason=finish_reason)
if finish_reason is not None
else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
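The string-splitting dance in `_stream` exists because the service can emit several JSON chunk objects concatenated on a single line, so a plain `json.loads` per line would fail. Every chunk object ends with the literal `"object":"chat.completion.chunk"}`, and the code uses the positions of that suffix as split points. A self-contained illustration of the same technique (the payload below is invented for the example; `re.escape` is added here for strictness, while the PR matches the raw substring):

```python
import json
import re

substring = '"object":"chat.completion.chunk"}'

# Two chunk objects glued together on one line, as the server may emit them.
line = (
    '{"choices":[{"delta":{"role":"assistant","content":"Par"}}],'
    '"object":"chat.completion.chunk"}'
    '{"choices":[{"delta":{"content":"is"}}],"object":"chat.completion.chunk"}'
)

# Each occurrence of the suffix marks the end of one chunk object.
positions = [m.start() for m in re.finditer(re.escape(substring), line)]
positions = [-1 * len(substring)] + positions
chunks = [
    json.loads(line[positions[i] + len(substring) : positions[i + 1] + len(substring)])
    for i in range(len(positions) - 1)
]
assert [c["choices"][0]["delta"].get("content") for c in chunks] == ["Par", "is"]
```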
def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.service_url is None:
res = requests.models.Response()
@@ -114,10 +204,17 @@ class WasmChatService(BaseChatModel):
service_url = f"{self.service_url}/v1/chat/completions"
payload = {
"model": self.model,
"messages": [_convert_message_to_dict(m) for m in messages],
}
if self.streaming:
payload = {
"model": self.model,
"messages": [_convert_message_to_dict(m) for m in messages],
"stream": self.streaming,
}
else:
payload = {
"model": self.model,
"messages": [_convert_message_to_dict(m) for m in messages],
}
res = requests.post(
url=service_url,

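With streaming enabled, the only difference in the request body is the extra `stream` field; an illustrative payload (messages abbreviated):

```python
# What _chat sends for a streaming call; a non-streaming call is the
# same dict minus the "stream" key. Values here are illustrative.
payload = {
    "model": "NA",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "stream": True,
}
```

Since the two branches differ only in that one key, they could equally be written as a single dict with `stream` added conditionally; the explicit if/else keeps the non-streaming path byte-for-byte unchanged.
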
View File

@@ -0,0 +1,52 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
@pytest.mark.enable_socket
def test_chat_wasm_service() -> None:
"""This test requires the port 8080 is not occupied."""
# service url
service_url = "https://b008-54-186-154-209.ngrok-free.app"
# create wasm-chat service instance
chat = LlamaEdgeChatService(service_url=service_url)
# create message sequence
system_message = SystemMessage(content="You are an AI assistant")
user_message = HumanMessage(content="What is the capital of France?")
messages = [system_message, user_message]
# chat with wasm-chat service
response = chat(messages)
# check response
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "Paris" in response.content
@pytest.mark.enable_socket
def test_chat_wasm_service_streaming() -> None:
"""This test requires the port 8080 is not occupied."""
# service url
service_url = "https://b008-54-186-154-209.ngrok-free.app"
# create wasm-chat service instance
chat = LlamaEdgeChatService(service_url=service_url, streaming=True)
# create message sequence
user_message = HumanMessage(content="What is the capital of France?")
messages = [
user_message,
]
output = ""
for chunk in chat.stream(messages):
print(chunk.content, end="", flush=True)
output += chunk.content
assert "Paris" in output

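The streaming test above accumulates `chunk.content` strings by hand; `AIMessageChunk` also supports `+`, so an equivalent check could fold the chunks directly (a sketch, assuming the same live service):

```python
from functools import reduce
from operator import add

chunks = list(chat.stream(messages))
full = reduce(add, chunks)  # message chunks concatenate via +
assert "Paris" in full.content
```
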
View File

@@ -1,28 +0,0 @@
import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_community.chat_models.wasm_chat import WasmChatService
@pytest.mark.enable_socket
def test_chat_wasm_service() -> None:
"""This test requires the port 8080 is not occupied."""
# service url
service_url = "https://b008-54-186-154-209.ngrok-free.app"
# create wasm-chat service instance
chat = WasmChatService(service_url=service_url)
# create message sequence
system_message = SystemMessage(content="You are an AI assistant")
user_message = HumanMessage(content="What is the capital of France?")
messages = [system_message, user_message]
# chat with wasm-chat service
response = chat(messages)
# check response
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert "Paris" in response.content

View File

@@ -33,7 +33,7 @@ EXPECTED_ALL = [
"ChatHunyuan",
"GigaChat",
"VolcEngineMaasChat",
"WasmChatService",
"LlamaEdgeChatService",
"GPTRouter",
"ChatZhipuAI",
]

View File

@@ -7,8 +7,8 @@ from langchain_core.messages import (
SystemMessage,
)
from langchain_community.chat_models.wasm_chat import (
WasmChatService,
from langchain_community.chat_models.llama_edge import (
LlamaEdgeChatService,
_convert_dict_to_message,
_convert_message_to_dict,
)
@@ -64,7 +64,7 @@ def test__convert_dict_to_message_other_role() -> None:
def test_wasm_chat_without_service_url() -> None:
chat = WasmChatService()
chat = LlamaEdgeChatService()
# create message sequence
system_message = SystemMessage(content="You are an AI assistant")
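
The updated test keeps exercising the no-`service_url` path. As the `_chat` excerpt earlier shows, a missing `service_url` yields a hand-built `requests.models.Response` rather than an exception; the diff truncates that branch, so the status code and body below are assumptions for illustration only:

```python
import requests

def _fake_error_response() -> requests.Response:
    # Synthetic Response, mirroring the pattern in _chat when service_url
    # is None. Status code and body here are assumed, not from the diff.
    res = requests.models.Response()
    res.status_code = 503
    res._content = b"The service_url is not set"
    return res
```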