From 3c8a115e21d8e3c4792809be9b7bd9e354e9ea8b Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Thu, 29 Feb 2024 19:20:02 -0800
Subject: [PATCH] fireworks[patch]: remove custom async and stream
 implementations (#18363)

---
 .../langchain_fireworks/chat_models.py        | 102 +-----------------
 libs/partners/fireworks/poetry.lock           |   2 +-
 libs/partners/fireworks/pyproject.toml        |   2 +-
 .../integration_tests/test_chat_models.py     |  61 +++++++++++
 .../tests/integration_tests/test_llms.py      |  61 +++++++++++
 5 files changed, 125 insertions(+), 103 deletions(-)

diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py
index 8844e379c3f..e30257d5bf4 100644
--- a/libs/partners/fireworks/langchain_fireworks/chat_models.py
+++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -7,10 +7,8 @@ import os
 from operator import itemgetter
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
     Dict,
-    Iterator,
     List,
     Literal,
     Mapping,
@@ -26,13 +24,11 @@ from typing import (
 from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
 from langchain_core._api import beta
 from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
-    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -57,7 +53,7 @@ from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
@@ -348,40 +344,6 @@ class ChatFireworks(BaseChatModel):
         combined["system_fingerprint"] = system_fingerprint
         return combined
 
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-
-        default_chunk_class = AIMessageChunk
-        for chunk in self.client.create(messages=message_dicts, **params):
-            if not isinstance(chunk, dict):
-                chunk = chunk.dict()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
-
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -438,68 +400,6 @@ class ChatFireworks(BaseChatModel):
         }
         return ChatResult(generations=generations, llm_output=llm_output)
 
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-
-        default_chunk_class = AIMessageChunk
-        async for chunk in await self.async_client.create(
-            messages=message_dicts, **params
-        ):
-            if not isinstance(chunk, dict):
-                chunk = chunk.dict()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
-                )
-            yield chunk
-
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return await agenerate_from_stream(stream_iter)
-
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {
-            **params,
-            **({"stream": stream} if stream is not None else {}),
-            **kwargs,
-        }
-        response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
-
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
diff --git a/libs/partners/fireworks/poetry.lock b/libs/partners/fireworks/poetry.lock
index 50a1b7be1fe..6a70c547613 100644
--- a/libs/partners/fireworks/poetry.lock
+++ b/libs/partners/fireworks/poetry.lock
@@ -572,7 +572,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.27"
+version = "0.1.28"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
diff --git a/libs/partners/fireworks/pyproject.toml b/libs/partners/fireworks/pyproject.toml
index 9bc42aa7486..c2899a9486d 100644
--- a/libs/partners/fireworks/pyproject.toml
+++ b/libs/partners/fireworks/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-fireworks"
-version = "0.1.0"
+version = "0.1.1"
 description = "An integration package connecting Fireworks and LangChain"
 authors = []
 readme = "README.md"
diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
index 1773173a2e7..27c38b29f1e 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -74,3 +74,64 @@ def test_tool_choice_bool() -> None:
         "name": "Erick",
     }
     assert tool_call["type"] == "function"
+
+
+def test_stream() -> None:
+    """Test streaming tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    for token in llm.stream("I'm Pickle Rick"):
+        assert isinstance(token.content, str)
+
+
+async def test_astream() -> None:
+    """Test streaming tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token.content, str)
+
+
+async def test_abatch() -> None:
+    """Test abatch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+async def test_abatch_tags() -> None:
+    """Test batch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.abatch(
+        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+def test_batch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+async def test_ainvoke() -> None:
+    """Test invoke tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
+    assert isinstance(result.content, str)
+
+
+def test_invoke() -> None:
+    """Test invoke tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    assert isinstance(result.content, str)
diff --git a/libs/partners/fireworks/tests/integration_tests/test_llms.py b/libs/partners/fireworks/tests/integration_tests/test_llms.py
index b22e7ef10f0..4ba2d065c9b 100644
--- a/libs/partners/fireworks/tests/integration_tests/test_llms.py
+++ b/libs/partners/fireworks/tests/integration_tests/test_llms.py
@@ -39,3 +39,64 @@ async def test_fireworks_acall() -> None:
     output_text = output.generations[0][0].text
     assert isinstance(output_text, str)
     assert output_text.count("bar") <= 1
+
+
+def test_stream() -> None:
+    """Test streaming tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    for token in llm.stream("I'm Pickle Rick"):
+        assert isinstance(token, str)
+
+
+async def test_astream() -> None:
+    """Test streaming tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token, str)
+
+
+async def test_abatch() -> None:
+    """Test abatch tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token, str)
+
+
+async def test_abatch_tags() -> None:
+    """Test batch tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.abatch(
+        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
+    )
+    for token in result:
+        assert isinstance(token, str)
+
+
+def test_batch() -> None:
+    """Test batch tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token, str)
+
+
+async def test_ainvoke() -> None:
+    """Test invoke tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
+    assert isinstance(result, str)
+
+
+def test_invoke() -> None:
+    """Test invoke tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    assert isinstance(result, str)
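
Note: with the custom _stream/_astream/_agenerate overrides removed, streaming and async calls on ChatFireworks and Fireworks go through the shared base-class code paths, which is exactly what the new integration tests above exercise. The following is a minimal usage sketch (an illustration, not part of the patch); it assumes langchain-fireworks is installed and the FIREWORKS_API_KEY environment variable is set.

# Usage sketch mirroring the new integration tests.
import asyncio

from langchain_fireworks import ChatFireworks, Fireworks


def sync_demo() -> None:
    chat = ChatFireworks()
    # Streaming now flows through the inherited chat-model implementation.
    for chunk in chat.stream("I'm Pickle Rick"):
        print(chunk.content, end="", flush=True)

    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    # Completion-style LLM calls return plain strings.
    print(llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]}))


async def async_demo() -> None:
    chat = ChatFireworks()
    # Async streaming and invocation use the inherited async fallbacks.
    async for chunk in chat.astream("I'm Pickle Rick"):
        print(chunk.content, end="", flush=True)
    result = await chat.ainvoke("I'm Pickle Rick")
    print(result.content)


if __name__ == "__main__":
    sync_demo()
    asyncio.run(async_demo())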