fireworks[patch]: Fix fireworks async stream (#18372)
- **Description:** Fix the async stream issue with Fireworks
- **Dependencies:** fireworks >= 0.13.0

Integration tests:

```
tests/integration_tests/test_chat_models.py ..........               [ 45%]
tests/integration_tests/test_compile.py .                            [ 50%]
tests/integration_tests/test_embeddings.py ..                        [ 59%]
tests/integration_tests/test_llms.py .........                       [100%]
```

Unit tests:

```
tests/unit_tests/test_embeddings.py .                                [ 16%]
tests/unit_tests/test_imports.py .                                   [ 33%]
tests/unit_tests/test_llms.py ....                                   [100%]
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
commit 2b93206f02
parent 1deb8cadd5
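For context, the user-facing effect of this fix is that async streaming works again. A minimal consumer sketch, not taken from the PR itself (the model id and prompt are illustrative; assumes a valid `FIREWORKS_API_KEY` in the environment):

```python
import asyncio

from langchain_fireworks import ChatFireworks


async def main() -> None:
    # Model id is illustrative; any Fireworks chat model id works here.
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    # Before this patch, astream() failed on the async client path; with
    # fireworks >= 0.13.0 it yields message chunks as expected.
    async for chunk in llm.astream("Tell me a one-line joke"):
        print(chunk.content, end="", flush=True)
    print()


asyncio.run(main())
```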
**`libs/partners/fireworks/Makefile`**

```diff
@@ -5,11 +5,9 @@ all: help
 # Define a variable for the test file path.
 TEST_FILE ?= tests/unit_tests/
+integration_test integration_tests: TEST_FILE ?= tests/integration_tests/
 
-test:
-	poetry run pytest $(TEST_FILE)
-
-tests:
+test tests integration_test integration_tests:
 	poetry run pytest $(TEST_FILE)
 
 
```
**`libs/partners/fireworks/langchain_fireworks/chat_models.py`**

```diff
@@ -7,8 +7,10 @@ import os
 from operator import itemgetter
 from typing import (
     Any,
+    AsyncIterator,
     Callable,
     Dict,
+    Iterator,
     List,
     Literal,
     Mapping,
@@ -24,11 +26,13 @@ from typing import (
 from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
 from langchain_core._api import beta
 from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
+    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -53,7 +57,7 @@ from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
-from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
@@ -344,6 +348,40 @@ class ChatFireworks(BaseChatModel):
             combined["system_fingerprint"] = system_fingerprint
         return combined
 
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        for chunk in self.client.create(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
+            yield chunk
+
     def _generate(
         self,
         messages: List[BaseMessage],
```
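For reference, `_convert_delta_to_message_chunk` above consumes the `delta` of each streamed choice. A sketch of one chunk after the `.dict()` normalization, assuming the OpenAI-compatible streaming format the Fireworks client emits (field values illustrative):

```python
# One streamed chunk after .dict() normalization (values illustrative).
chunk = {
    "choices": [
        {
            "delta": {"role": "assistant", "content": "Hel"},  # partial text
            "finish_reason": None,  # set (e.g. "stop") only on the final chunk
            "logprobs": None,
        }
    ],
}
```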
**`libs/partners/fireworks/langchain_fireworks/chat_models.py`** (continued)

```diff
@@ -400,6 +438,66 @@ class ChatFireworks(BaseChatModel):
         }
         return ChatResult(generations=generations, llm_output=llm_output)
 
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs, "stream": True}
+
+        default_chunk_class = AIMessageChunk
+        async for chunk in self.async_client.acreate(messages=message_dicts, **params):
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+            choice = chunk["choices"][0]
+            chunk = _convert_delta_to_message_chunk(
+                choice["delta"], default_chunk_class
+            )
+            generation_info = {}
+            if finish_reason := choice.get("finish_reason"):
+                generation_info["finish_reason"] = finish_reason
+            logprobs = choice.get("logprobs")
+            if logprobs:
+                generation_info["logprobs"] = logprobs
+            default_chunk_class = chunk.__class__
+            chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info or None
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    token=chunk.text, chunk=chunk, logprobs=logprobs
+                )
+            yield chunk
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {
+            **params,
+            **({"stream": stream} if stream is not None else {}),
+            **kwargs,
+        }
+        response = await self.async_client.acreate(messages=message_dicts, **params)
+        return self._create_chat_result(response)
+
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
```
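The new `_agenerate` routes through `_astream` plus `agenerate_from_stream` when streaming is requested, so a non-streaming call site still receives one aggregated `ChatResult`. A hedged usage sketch (model id illustrative; assumes `FIREWORKS_API_KEY` is set):

```python
import asyncio

from langchain_core.messages import HumanMessage
from langchain_fireworks import ChatFireworks


async def main() -> None:
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
    # stream=True makes _agenerate delegate to _astream and assemble the
    # final result from the streamed chunks via agenerate_from_stream.
    result = await llm.agenerate([[HumanMessage(content="Say hello")]], stream=True)
    print(result.generations[0][0].text)


asyncio.run(main())
```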
**`libs/partners/fireworks/langchain_fireworks/llms.py`**

```diff
@@ -213,10 +213,5 @@ class Fireworks(LLM):
         )
 
         response_json = await response.json()
-
-        if response_json.get("status") != "finished":
-            err_msg = response_json.get("error", "Undefined Error")
-            raise Exception(err_msg)
-
         output = self._format_output(response_json)
         return output
```
**`libs/partners/fireworks/poetry.lock`** (generated)

```diff
@@ -341,13 +341,13 @@ test = ["pytest (>=6)"]
 
 [[package]]
 name = "fireworks-ai"
-version = "0.12.1"
+version = "0.13.0"
 description = "Python client library for the Fireworks.ai Generative AI Platform"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "fireworks-ai-0.12.1.tar.gz", hash = "sha256:77a3b3be243182548cb5f690f60528f09ca2a7223b871e47fc4e9d13a0df5c1b"},
-    {file = "fireworks_ai-0.12.1-py3-none-any.whl", hash = "sha256:f78dc61f46c534ba045ad111fc7eeed6ea7ff022e7dce446dd23f56ebad371e7"},
+    {file = "fireworks-ai-0.13.0.tar.gz", hash = "sha256:d6db1e60f65f237b6e87e3e9c028681be0ba77496df398db386ae2876dab54e0"},
+    {file = "fireworks_ai-0.13.0-py3-none-any.whl", hash = "sha256:900559d7eeea8a86dc5789f9034b3873684a685a2e96b56a63e6be3a04803eb6"},
 ]
 
 [package.dependencies]
@@ -1149,13 +1149,13 @@ watchdog = ">=2.0.0"
 
 [[package]]
 name = "python-dateutil"
-version = "2.8.2"
+version = "2.9.0"
 description = "Extensions to the standard Python datetime module"
 optional = false
 python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
 files = [
-    {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
-    {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
+    {file = "python-dateutil-2.9.0.tar.gz", hash = "sha256:78e73e19c63f5b20ffa567001531680d939dc042bf7850431877645523c66709"},
+    {file = "python_dateutil-2.9.0-py2.py3-none-any.whl", hash = "sha256:cbf2f1da5e6083ac2fbfd4da39a25f34312230110440f424a14c7558bb85d82e"},
 ]
 
 [package.dependencies]
@@ -1538,4 +1538,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "b9ee2bfb5053127cb29b8182baea395897744d7b5e8b985c42863133a26708ba"
+content-hash = "ab5538b63e5d347dadcad268e135a5ca9fb5bc2edd2436dcee99c55a7ee4b609"
```
**`libs/partners/fireworks/pyproject.toml`**

```diff
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = "^0.1.27"
-fireworks-ai = ">=0.12.0,<1"
+fireworks-ai = ">=0.13.0"
 openai = "^1.10.0"
 requests = "^2"
 aiohttp = "^3.9.1"
```
Loading…
Reference in New Issue
Block a user