mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
fireworks[patch]: remove custom async and stream implementations (#18363)
This commit is contained in:
parent
4730ee2766
commit
3c8a115e21
@ -7,10 +7,8 @@ import os
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
@ -26,13 +24,11 @@ from typing import (
|
|||||||
from fireworks.client import AsyncFireworks, Fireworks # type: ignore
|
from fireworks.client import AsyncFireworks, Fireworks # type: ignore
|
||||||
from langchain_core._api import beta
|
from langchain_core._api import beta
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import (
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
agenerate_from_stream,
|
|
||||||
generate_from_stream,
|
generate_from_stream,
|
||||||
)
|
)
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -57,7 +53,7 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
JsonOutputKeyToolsParser,
|
JsonOutputKeyToolsParser,
|
||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@ -348,40 +344,6 @@ class ChatFireworks(BaseChatModel):
|
|||||||
combined["system_fingerprint"] = system_fingerprint
|
combined["system_fingerprint"] = system_fingerprint
|
||||||
return combined
|
return combined
|
||||||
|
|
||||||
def _stream(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
||||||
params = {**params, **kwargs, "stream": True}
|
|
||||||
|
|
||||||
default_chunk_class = AIMessageChunk
|
|
||||||
for chunk in self.client.create(messages=message_dicts, **params):
|
|
||||||
if not isinstance(chunk, dict):
|
|
||||||
chunk = chunk.dict()
|
|
||||||
if len(chunk["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
chunk = _convert_delta_to_message_chunk(
|
|
||||||
choice["delta"], default_chunk_class
|
|
||||||
)
|
|
||||||
generation_info = {}
|
|
||||||
if finish_reason := choice.get("finish_reason"):
|
|
||||||
generation_info["finish_reason"] = finish_reason
|
|
||||||
logprobs = choice.get("logprobs")
|
|
||||||
if logprobs:
|
|
||||||
generation_info["logprobs"] = logprobs
|
|
||||||
default_chunk_class = chunk.__class__
|
|
||||||
chunk = ChatGenerationChunk(
|
|
||||||
message=chunk, generation_info=generation_info or None
|
|
||||||
)
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
@ -438,68 +400,6 @@ class ChatFireworks(BaseChatModel):
|
|||||||
}
|
}
|
||||||
return ChatResult(generations=generations, llm_output=llm_output)
|
return ChatResult(generations=generations, llm_output=llm_output)
|
||||||
|
|
||||||
async def _astream(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
||||||
params = {**params, **kwargs, "stream": True}
|
|
||||||
|
|
||||||
default_chunk_class = AIMessageChunk
|
|
||||||
async for chunk in await self.async_client.create(
|
|
||||||
messages=message_dicts, **params
|
|
||||||
):
|
|
||||||
if not isinstance(chunk, dict):
|
|
||||||
chunk = chunk.dict()
|
|
||||||
if len(chunk["choices"]) == 0:
|
|
||||||
continue
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
chunk = _convert_delta_to_message_chunk(
|
|
||||||
choice["delta"], default_chunk_class
|
|
||||||
)
|
|
||||||
generation_info = {}
|
|
||||||
if finish_reason := choice.get("finish_reason"):
|
|
||||||
generation_info["finish_reason"] = finish_reason
|
|
||||||
logprobs = choice.get("logprobs")
|
|
||||||
if logprobs:
|
|
||||||
generation_info["logprobs"] = logprobs
|
|
||||||
default_chunk_class = chunk.__class__
|
|
||||||
chunk = ChatGenerationChunk(
|
|
||||||
message=chunk, generation_info=generation_info or None
|
|
||||||
)
|
|
||||||
if run_manager:
|
|
||||||
await run_manager.on_llm_new_token(
|
|
||||||
token=chunk.text, chunk=chunk, logprobs=logprobs
|
|
||||||
)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def _agenerate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
||||||
stream: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
should_stream = stream if stream is not None else self.streaming
|
|
||||||
if should_stream:
|
|
||||||
stream_iter = self._astream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await agenerate_from_stream(stream_iter)
|
|
||||||
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
||||||
params = {
|
|
||||||
**params,
|
|
||||||
**({"stream": stream} if stream is not None else {}),
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
response = await self.async_client.create(messages=message_dicts, **params)
|
|
||||||
return self._create_chat_result(response)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Dict[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
|
2
libs/partners/fireworks/poetry.lock
generated
2
libs/partners/fireworks/poetry.lock
generated
@ -572,7 +572,7 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.1.27"
|
version = "0.1.28"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "langchain-fireworks"
|
name = "langchain-fireworks"
|
||||||
version = "0.1.0"
|
version = "0.1.1"
|
||||||
description = "An integration package connecting Fireworks and LangChain"
|
description = "An integration package connecting Fireworks and LangChain"
|
||||||
authors = []
|
authors = []
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -74,3 +74,64 @@ def test_tool_choice_bool() -> None:
|
|||||||
"name": "Erick",
|
"name": "Erick",
|
||||||
}
|
}
|
||||||
assert tool_call["type"] == "function"
|
assert tool_call["type"] == "function"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream() -> None:
|
||||||
|
"""Test streaming tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
for token in llm.stream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_astream() -> None:
|
||||||
|
"""Test streaming tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
async for token in llm.astream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch() -> None:
|
||||||
|
"""Test abatch tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch_tags() -> None:
|
||||||
|
"""Test batch tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
result = await llm.abatch(
|
||||||
|
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||||
|
)
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch() -> None:
|
||||||
|
"""Test batch tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ainvoke() -> None:
|
||||||
|
"""Test invoke tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke() -> None:
|
||||||
|
"""Test invoke tokens from ChatFireworks."""
|
||||||
|
llm = ChatFireworks()
|
||||||
|
|
||||||
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
|
assert isinstance(result.content, str)
|
||||||
|
@ -39,3 +39,64 @@ async def test_fireworks_acall() -> None:
|
|||||||
output_text = output.generations[0][0].text
|
output_text = output.generations[0][0].text
|
||||||
assert isinstance(output_text, str)
|
assert isinstance(output_text, str)
|
||||||
assert output_text.count("bar") <= 1
|
assert output_text.count("bar") <= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
for token in llm.stream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_astream() -> None:
|
||||||
|
"""Test streaming tokens from OpenAI."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
async for token in llm.astream("I'm Pickle Rick"):
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch() -> None:
|
||||||
|
"""Test streaming tokens from Fireworks."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch_tags() -> None:
|
||||||
|
"""Test batch tokens from Fireworks."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
result = await llm.abatch(
|
||||||
|
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||||
|
)
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch() -> None:
|
||||||
|
"""Test batch tokens from Fireworks."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||||
|
for token in result:
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ainvoke() -> None:
|
||||||
|
"""Test invoke tokens from Fireworks."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke() -> None:
|
||||||
|
"""Test invoke tokens from Fireworks."""
|
||||||
|
llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
|
||||||
|
|
||||||
|
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||||
|
assert isinstance(result, str)
|
||||||
|
Loading…
Reference in New Issue
Block a user