Compare commits

1 commit

Author       SHA1        Message                                 Date
Erick Friis  3172b8c1ce  mistralai[patch]: enforce stop tokens   2024-01-09 19:20:48 -08:00

2 changed files with 130 additions and 18 deletions

File: langchain_mistralai/chat_models.py

@@ -2,11 +2,14 @@ from __future__ import annotations

 import importlib.util
 import logging
+import re
+from collections import deque
 from typing import (
     Any,
     AsyncIterator,
     Callable,
     Dict,
+    Generator,
     Iterator,
     List,
     Optional,
@@ -89,16 +92,19 @@ def _create_retry_decorator(

 def _convert_mistral_chat_message_to_message(
     _message: MistralChatMessage,
+    stop: Optional[List[str]] = None,
 ) -> BaseMessage:
     role = _message.role
     if role == "user":
-        return HumanMessage(content=_message.content)
+        return HumanMessage(content=_enforce_stop_tokens(_message.content, stop))
     elif role == "assistant":
-        return AIMessage(content=_message.content)
+        return AIMessage(content=_enforce_stop_tokens(_message.content, stop))
     elif role == "system":
-        return SystemMessage(content=_message.content)
+        return SystemMessage(content=_enforce_stop_tokens(_message.content, stop))
     else:
-        return ChatMessage(content=_message.content, role=role)
+        return ChatMessage(
+            content=_enforce_stop_tokens(_message.content, stop), role=role
+        )


 async def acompletion_with_retry(
@@ -153,6 +159,63 @@ def _convert_message_to_mistral_chat_message(
     return mistral_message


+def _enforce_stop_tokens(text: str, stop: Optional[List[str]]) -> str:
+    if not stop:
+        return text
+    regex = "|".join([re.escape(s) for s in stop])
+    return re.split(regex, text, maxsplit=1)[0]
+
+
+def _enforce_stop_tokens_stream(
+    gen: Generator[ChatGenerationChunk, None, None], stop: List[str]
+) -> Generator[ChatGenerationChunk, None, None]:
+    if not stop:
+        yield from gen
+        return
+    regex = "|".join([re.escape(s) for s in stop])
+    longest = max(len(s) for s in stop)
+
+    accumulator = deque()
+    chars_to_yield = -1
+    for chunk in gen:
+        accumulator.append(chunk)
+        accumulator_str = "".join([c.message.content for c in accumulator])
+        match = re.search(regex, accumulator_str)
+        if match:
+            chars_to_yield = match.start()
+            break
+        accumulator_len = len(accumulator_str)
+        next_chunk = accumulator[0]
+        while accumulator_len - len(next_chunk.message.content) > longest:
+            accumulator.popleft()
+            accumulator_len -= len(next_chunk.message.content)
+            yield next_chunk
+            next_chunk = accumulator[0]
+
+    # if nothing matched, yield everything
+    if chars_to_yield == -1:
+        for chunk in accumulator:
+            yield chunk
+        return
+
+    # yield full chunks available
+    next_chunk = accumulator[0]
+    while chars_to_yield - len(next_chunk.message.content) >= 0:
+        yield next_chunk
+        chars_to_yield -= len(next_chunk.message.content)
+        accumulator.popleft()
+        next_chunk = accumulator[0]
+
+    # yield partial chunk if necessary
+    if chars_to_yield > 0:
+        new_message = next_chunk.message.copy()
+        new_message.content = new_message.content[:chars_to_yield]
+        new_chunk = ChatGenerationChunk(message=new_message)
+        yield new_chunk
+
+
 class ChatMistralAI(BaseChatModel):
     """A chat model that uses the MistralAI API."""
@@ -260,15 +323,15 @@ class ChatMistralAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)

-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages)
         params = {**params, **kwargs}
         response = self.completion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         )
-        return self._create_chat_result(response)
+        return self._create_chat_result(response, stop)

     def _create_chat_result(
-        self, response: MistralChatCompletionResponse
+        self, response: MistralChatCompletionResponse, stop: Optional[List[str]] = None
     ) -> ChatResult:
         generations = []
         for res in response.choices:
@@ -276,7 +339,7 @@ class ChatMistralAI(BaseChatModel):
             if finish_reason:
                 finish_reason = finish_reason.value
             gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res.message),
+                message=_convert_mistral_chat_message_to_message(res.message, stop),
                 generation_info={"finish_reason": finish_reason},
             )
             generations.append(gen)
@@ -286,15 +349,9 @@ class ChatMistralAI(BaseChatModel):
         return ChatResult(generations=generations, llm_output=llm_output)

     def _create_message_dicts(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
+        self, messages: List[BaseMessage]
     ) -> Tuple[List[MistralChatMessage], Dict[str, Any]]:
         params = self._client_params
-        if stop is not None or "stop" in params:
-            if "stop" in params:
-                params.pop("stop")
-            logger.warning(
-                "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
-            )
         message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
         return message_dicts, params
@@ -305,7 +362,7 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages)
         params = {**params, **kwargs, "stream": True}

         default_chunk_class = AIMessageChunk
@@ -330,7 +387,7 @@
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages)
         params = {**params, **kwargs, "stream": True}

         default_chunk_class = AIMessageChunk
@@ -363,7 +420,7 @@
             )
             return await agenerate_from_stream(stream_iter)

-        message_dicts, params = self._create_message_dicts(messages, stop)
+        message_dicts, params = self._create_message_dicts(messages)
         params = {**params, **kwargs}
         response = await acompletion_with_retry(
             self, messages=message_dicts, run_manager=run_manager, **params
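
Net effect of this file: `stop` is no longer forwarded to the Mistral API (whose docs at the time said the parameter was not yet supported, hence the deleted warning) and is instead enforced client-side on the returned completion. A usage sketch, assuming a valid MISTRAL_API_KEY in the environment; the completion text in the comment is illustrative:

    from langchain_mistralai.chat_models import ChatMistralAI

    chat = ChatMistralAI()  # reads MISTRAL_API_KEY from the environment
    # `stop` is applied after the response comes back, so a completion like
    # "4. Two plus two is four." would be truncated locally to "4".
    message = chat.invoke("What is 2 + 2?", stop=["."])
    print(message.content)
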

File: tests/unit_tests/test_chat_models.py

@@ -1,14 +1,17 @@
 """Test MistralAI Chat API wrapper."""
 import os
+from typing import Generator

 import pytest
 from langchain_core.messages import (
     AIMessage,
+    AIMessageChunk,
     BaseMessage,
     ChatMessage,
     HumanMessage,
     SystemMessage,
 )
+from langchain_core.outputs import ChatGenerationChunk

 # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
 from mistralai.models.chat_completion import (  # type: ignore[import]
@@ -18,6 +21,8 @@ from mistralai.models.chat_completion import (  # type: ignore[import]
 from langchain_mistralai.chat_models import (  # type: ignore[import]
     ChatMistralAI,
     _convert_message_to_mistral_chat_message,
+    _enforce_stop_tokens,
+    _enforce_stop_tokens_stream,
 )

 os.environ["MISTRAL_API_KEY"] = "foo"
@@ -63,3 +68,53 @@ def test_convert_message_to_mistral_chat_message(
 ) -> None:
     result = _convert_message_to_mistral_chat_message(message)
     assert result == expected
+
+
+def test_enforce_stop_tokens() -> None:
+    """Test _enforce_stop_tokens helper function."""
+    assert _enforce_stop_tokens("Hello", ["lo"]) == "Hel"
+    assert _enforce_stop_tokens("Hello there my friend", ["my"]) == "Hello there "
+    assert _enforce_stop_tokens("Hello there my friend", ["my", "there"]) == "Hello "
+
+    # Test regex special characters
+    assert (
+        _enforce_stop_tokens("Hello there? my friend", ["e?", "friend"]) == "Hello ther"
+    )
+
+
+def _string_to_2_char_stream(string: str) -> Generator[ChatGenerationChunk, None, None]:
+    """Convert a string to a stream of 2-character chunks."""
+    for i in range(0, len(string), 2):
+        message = AIMessageChunk(content=string[i : i + 2])
+        yield ChatGenerationChunk(message=message)
+
+
+def _stream_to_string(stream: Generator[ChatGenerationChunk, None, None]) -> str:
+    """Convert a stream of chunks to a string."""
+    return "".join([chunk.message.content for chunk in stream])
+
+
+def test_enforce_stop_tokens_stream() -> None:
+    """Test _enforce_stop_tokens_stream helper function."""
+    assert (
+        _stream_to_string(
+            _enforce_stop_tokens_stream(_string_to_2_char_stream("Hello"), ["lo"])
+        )
+        == "Hel"
+    )
+    assert (
+        _stream_to_string(
+            _enforce_stop_tokens_stream(
+                _string_to_2_char_stream("Hello there my friend"), ["my"]
+            )
+        )
+        == "Hello there "
+    )
+    assert (
+        _stream_to_string(
+            _enforce_stop_tokens_stream(
+                _string_to_2_char_stream("Hello there my friend"), ["my", "there"]
+            )
+        )
+        == "Hello "
+    )
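
One case the new tests do not exercise explicitly is a stop sequence that straddles a chunk boundary, which is precisely what the deque buffering in _enforce_stop_tokens_stream handles. A hypothetical extra test in the same style (not part of the diff):

    def test_enforce_stop_tokens_stream_chunk_boundary() -> None:
        """The stop sequence "o t" spans the 3rd and 4th chunks: He|ll|o |th|er|e."""
        assert (
            _stream_to_string(
                _enforce_stop_tokens_stream(
                    _string_to_2_char_stream("Hello there"), ["o t"]
                )
            )
            == "Hell"
        )
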