mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00

fireworks[patch]: Add Fireworks partner packages (#17694)

Co-authored-by: Erick Friis <erick@langchain.dev>

commit ee6a773456 (parent 11cf95e810), committed via GitHub
libs/partners/fireworks/.gitignore (new file, vendored, 1 line)
@@ -0,0 +1 @@
__pycache__
libs/partners/fireworks/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/fireworks/Makefile (new file, 59 lines)
@@ -0,0 +1,59 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
	poetry run pytest $(TEST_FILE)

tests:
	poetry run pytest $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/fireworks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_fireworks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_fireworks -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
libs/partners/fireworks/README.md (new file, 15 lines)
@@ -0,0 +1,15 @@
# LangChain-Fireworks

This is the partner package for connecting Fireworks.ai and LangChain. Fireworks strives to provide good support for LangChain use cases, so if you run into any issues please let us know. You can reach us [in our Discord channel](https://discord.com/channels/1137072072808472616/).

## Basic LangChain-Fireworks example
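A minimal sketch (assuming the package is installed, e.g. `pip install langchain-fireworks`, and that the `FIREWORKS_API_KEY` environment variable is set; the model name below is the package default):

```python
from langchain_fireworks import ChatFireworks

# Model names follow the Fireworks "accounts/.../models/..." convention used
# throughout this package; swap in any chat model your account can access.
chat = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")

response = chat.invoke("Tell me a one-sentence joke about fireworks.")
print(response.content)
```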
## Advanced

### Tool use: LangChain agent + Fireworks function-calling model

Check out how to teach the Fireworks function-calling model to use a [calculator here](https://github.com/fw-ai/cookbook/blob/main/examples/function_calling/fireworks_langchain_tool_usage.ipynb).

Fireworks focuses on delivering the best experience for fast model inference as well as tool use. See [our blog](https://fireworks.ai/blog/firefunction-v1-gpt-4-level-function-calling) for details on how it compares to GPT-4; the punchline is that it is on par with GPT-4 for function-calling use cases while being much faster and much cheaper.

### RAG: LangChain agent + Fireworks function-calling model + MongoDB + Nomic AI embeddings

Check out the [cookbook here](https://github.com/fw-ai/cookbook/blob/main/examples/rag/mongodb_agent.ipynb) for an end-to-end flow.
libs/partners/fireworks/langchain_fireworks/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from langchain_fireworks.chat_models import ChatFireworks
from langchain_fireworks.embeddings import FireworksEmbeddings
from langchain_fireworks.llms import Fireworks
from langchain_fireworks.version import __version__

__all__ = [
    "__version__",
    "ChatFireworks",
    "Fireworks",
    "FireworksEmbeddings",
]
libs/partners/fireworks/langchain_fireworks/chat_models.py (new file, 615 lines)
@@ -0,0 +1,615 @@
"""Fireworks chat wrapper."""

from __future__ import annotations

import logging
import os
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
    cast,
)

from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
    convert_to_openai_function,
    convert_to_openai_tool,
)
from langchain_core.utils.utils import build_extra_kwargs

logger = logging.getLogger(__name__)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    role = _dict.get("role")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""))
    elif role == "assistant":
        # Fix for azure
        # Also Fireworks returns None for tool invocations
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        if tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = tool_calls
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    elif role == "system":
        return SystemMessage(content=_dict.get("content", ""))
    elif role == "function":
        return FunctionMessage(
            content=_dict.get("content", ""), name=_dict.get("name", "")
        )
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=_dict.get("tool_call_id", ""),
            additional_kwargs=additional_kwargs,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role or "")


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
        if "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            # If tool calls only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "role": "tool",
            "content": message.content,
            "tool_call_id": message.tool_call_id,
        }
    else:
        raise TypeError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


def _convert_delta_to_message_chunk(
    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = cast(str, _dict.get("role"))
    content = cast(str, _dict.get("content") or "")
    additional_kwargs: Dict = {}
    if _dict.get("function_call"):
        function_call = dict(_dict["function_call"])
        if "name" in function_call and function_call["name"] is None:
            function_call["name"] = ""
        additional_kwargs["function_call"] = function_call
    if _dict.get("tool_calls"):
        additional_kwargs["tool_calls"] = _dict["tool_calls"]

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role == "function" or default_class == FunctionMessageChunk:
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role == "tool" or default_class == ToolMessageChunk:
        return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)  # type: ignore


class _FunctionCall(TypedDict):
    name: str


# This is basically a copy and replace for ChatOpenAI, except
# - I needed to gut out tiktoken and some of the token estimation logic
#   (not sure how important it is)
# - Environment variable is different
# we should refactor into some OpenAI-like class in the future
class ChatFireworks(BaseChatModel):
    """`Fireworks` Chat large language models API.

    To use, you should have the
    environment variable ``FIREWORKS_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the fireworks.create call
    can be passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_fireworks.chat_models import ChatFireworks
            fireworks = ChatFireworks(
                model_name="accounts/fireworks/models/mixtral-8x7b-instruct")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "fireworks"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}
        if self.fireworks_api_base:
            attributes["fireworks_api_base"] = self.fireworks_api_base

        return attributes

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model_name: str = Field(
        default="accounts/fireworks/models/mixtral-8x7b-instruct", alias="model"
    )
    """Model name to use."""
    temperature: float = 0.0
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
    """Automatically inferred from env var `FIREWORKS_API_KEY` if not provided."""
    fireworks_api_base: Optional[str] = Field(default=None, alias="base_url")
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Fireworks completion API. Can be float, httpx.Timeout or
    None."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        values["fireworks_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
        )
        values["fireworks_api_base"] = values["fireworks_api_base"] or os.getenv(
            "FIREWORKS_API_BASE"
        )
        client_params = {
            "api_key": (
                values["fireworks_api_key"].get_secret_value()
                if values["fireworks_api_key"]
                else None
            ),
            "base_url": values["fireworks_api_base"],
            "timeout": values["request_timeout"],
        }

        if not values.get("client"):
            values["client"] = Fireworks(**client_params).chat.completions
        if not values.get("async_client"):
            values["async_client"] = AsyncFireworks(**client_params).chat.completions
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Fireworks API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        system_fingerprint = None
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
            if system_fingerprint is None:
                system_fingerprint = output.get("system_fingerprint")
        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
        if system_fingerprint:
            combined["system_fingerprint"] = system_fingerprint
        return combined

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        for chunk in self.client.create(messages=message_dicts, **params):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = self.client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            if "logprobs" in res:
                generation_info["logprobs"] = res["logprobs"]
            gen = ChatGeneration(
                message=message,
                generation_info=generation_info,
            )
            generations.append(gen)
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
            "system_fingerprint": response.get("system_fingerprint", ""),
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class = AIMessageChunk
        async for chunk in await self.async_client.create(
            messages=message_dicts, **params
        ):
            if not isinstance(chunk, dict):
                chunk = chunk.dict()
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]
            chunk = _convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )
            generation_info = {}
            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason
            logprobs = choice.get("logprobs")
            if logprobs:
                generation_info["logprobs"] = logprobs
            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(
                message=chunk, generation_info=generation_info or None
            )
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=chunk.text, chunk=chunk, logprobs=logprobs
                )
            yield chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {
            **params,
            **({"stream": stream} if stream is not None else {}),
            **kwargs,
        }
        response = await self.async_client.create(messages=message_dicts, **params)
        return self._create_chat_result(response)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name, **self._default_params}

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return {
            "model": self.model_name,
            **super()._get_invocation_params(stop=stop),
            **self._default_params,
            **kwargs,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "fireworks-chat"

    def bind_functions(
        self,
        functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        function_call: Optional[
            Union[_FunctionCall, str, Literal["auto", "none"]]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind functions (and other objects) to this chat model.

        Assumes model is compatible with Fireworks function-calling API.

        NOTE: Using bind_tools is recommended instead, as the `functions` and
        `function_call` request parameters are officially marked as deprecated by
        Fireworks.

        Args:
            functions: A list of function definitions to bind to this chat model.
                Can be a dictionary, pydantic model, or callable. Pydantic
                models and callables will be automatically converted to
                their schema dictionary representation.
            function_call: Which function to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any).
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """

        formatted_functions = [convert_to_openai_function(fn) for fn in functions]
        if function_call is not None:
            function_call = (
                {"name": function_call}
                if isinstance(function_call, str)
                and function_call not in ("auto", "none")
                else function_call
            )
            if isinstance(function_call, dict) and len(formatted_functions) != 1:
                raise ValueError(
                    "When specifying `function_call`, you must provide exactly one "
                    "function."
                )
            if (
                isinstance(function_call, dict)
                and formatted_functions[0]["name"] != function_call["name"]
            ):
                raise ValueError(
                    f"Function call {function_call} was specified, but the only "
                    f"provided function was {formatted_functions[0]['name']}."
                )
            kwargs = {**kwargs, "function_call": function_call}
        return super().bind(
            functions=formatted_functions,
            **kwargs,
        )

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[Union[dict, str, Literal["auto", "none"]]] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with Fireworks tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function or
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """

        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice is not None:
            if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
                tool_choice = {"type": "function", "function": {"name": tool_choice}}
            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                raise ValueError(
                    "When specifying `tool_choice`, you must provide exactly one "
                    f"tool. Received {len(formatted_tools)} tools."
                )
            if isinstance(tool_choice, dict) and (
                formatted_tools[0]["function"]["name"]
                != tool_choice["function"]["name"]
            ):
                raise ValueError(
                    f"Tool choice {tool_choice} was specified, but the only "
                    f"provided tool was {formatted_tools[0]['function']['name']}."
                )
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)
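A usage sketch for the `bind_tools` method above. The `GetWeather` tool is hypothetical, and the model name is an assumption (a function-calling-capable Fireworks model such as the FireFunction model referenced in the README):

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_fireworks import ChatFireworks

class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str = Field(description="City name, e.g. 'Paris'")

chat = ChatFireworks(model="accounts/fireworks/models/firefunction-v1")
# tool_choice as a bare name is expanded by bind_tools into
# {"type": "function", "function": {"name": "GetWeather"}}.
chat_with_tools = chat.bind_tools([GetWeather], tool_choice="GetWeather")

# Requested tool calls arrive in additional_kwargs["tool_calls"], the
# OpenAI-style payload this wrapper converts to and from.
msg = chat_with_tools.invoke("What's the weather in Paris?")
print(msg.additional_kwargs.get("tool_calls"))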
File diff suppressed because one or more lines are too long
libs/partners/fireworks/langchain_fireworks/embeddings.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import os
from typing import Any, Dict, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from openai import OpenAI  # type: ignore


class FireworksEmbeddings(BaseModel, Embeddings):
    """FireworksEmbeddings embedding model.

    Example:
        .. code-block:: python

            from langchain_fireworks import FireworksEmbeddings

            model = FireworksEmbeddings(
                model='nomic-ai/nomic-embed-text-v1.5'
            )
    """

    _client: OpenAI = Field(default=None)
    fireworks_api_key: SecretStr = convert_to_secret_str("")
    model: str = "nomic-ai/nomic-embed-text-v1.5"

    @root_validator()
    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate environment variables."""
        fireworks_api_key = convert_to_secret_str(
            values.get("fireworks_api_key") or os.getenv("FIREWORKS_API_KEY") or ""
        )
        values["fireworks_api_key"] = fireworks_api_key

        # note this sets it globally for module
        # there isn't currently a way to pass it into client
        api_key = fireworks_api_key.get_secret_value()
        values["_client"] = OpenAI(
            api_key=api_key, base_url="https://api.fireworks.ai/inference/v1"
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs."""
        return [
            i.embedding
            for i in self._client.embeddings.create(input=texts, model=self.model).data
        ]

    def embed_query(self, text: str) -> List[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]
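A short usage sketch for the embedding class above (assumes `FIREWORKS_API_KEY` is set; the model name is the class default):

from langchain_fireworks import FireworksEmbeddings

embedder = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")

# embed_documents batches through the OpenAI-compatible /embeddings route on
# api.fireworks.ai; embed_query wraps a single-item batch and returns its vector.
vectors = embedder.embed_documents(["hello world", "fireworks"])
query_vec = embedder.embed_query("hello world")
print(len(vectors), len(query_vec))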
libs/partners/fireworks/langchain_fireworks/llms.py (new file, 222 lines)
@@ -0,0 +1,222 @@
"""Wrapper around Fireworks AI's Completion API."""

import logging
from typing import Any, Dict, List, Optional

import requests
from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
    get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs

from langchain_fireworks.version import __version__

logger = logging.getLogger(__name__)


class Fireworks(LLM):
    """LLM models from `Fireworks`.

    To use, you'll need an API key which you can find here:
    https://fireworks.ai This can be passed in as init param
    ``fireworks_api_key`` or set as environment variable ``FIREWORKS_API_KEY``.

    Fireworks AI API reference: https://readme.fireworks.ai/

    Example:
        .. code-block:: python

            response = fireworks.generate(["Tell me a joke."])
    """

    base_url: str = "https://api.fireworks.ai/inference/v1/completions"
    """Base inference API URL."""
    fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
    """Fireworks AI API key. Get it here: https://fireworks.ai"""
    model: str
    """Model name. Available models listed here:
    https://readme.fireworks.ai/
    """
    temperature: Optional[float] = None
    """Model temperature."""
    top_p: Optional[float] = None
    """Used to dynamically adjust the number of choices for each predicted token
    based on the cumulative probabilities (nucleus sampling). A value of 1
    considers all tokens; a value less than 1 restricts sampling to the most
    probable tokens, favoring more focused output, which is appropriate for
    question answering or summarization.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    top_k: Optional[int] = None
    """Used to limit the number of choices for the next predicted word or token. It
    specifies the maximum number of tokens to consider at each step, based on their
    probability of occurrence. This technique helps to speed up the generation
    process and can improve the quality of the generated text by focusing on the
    most likely options.
    """
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    repetition_penalty: Optional[float] = None
    """A number that controls the diversity of generated text by reducing the
    likelihood of repeated sequences. Higher values decrease repetition.
    """
    logprobs: Optional[int] = None
    """An integer that specifies how many top token log probabilities are included in
    the response for each token generation step.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["fireworks_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "fireworks"

    def _format_output(self, output: dict) -> str:
        return output["choices"][0]["text"]

    @staticmethod
    def get_user_agent() -> str:
        return f"langchain-fireworks/{__version__}"

    @property
    def default_params(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "repetition_penalty": self.repetition_penalty,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Fireworks's text generation endpoint.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.
        """
        headers = {
            "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
            "Content-Type": "application/json",
        }
        stop_to_use = stop[0] if stop and len(stop) == 1 else stop
        payload: Dict[str, Any] = {
            **self.default_params,
            "prompt": prompt,
            "stop": stop_to_use,
            **kwargs,
        }

        # filter None values to not pass them to the http payload
        payload = {k: v for k, v in payload.items() if v is not None}
        response = requests.post(url=self.base_url, json=payload, headers=headers)

        if response.status_code >= 500:
            raise Exception(f"Fireworks Server: Error {response.status_code}")
        elif response.status_code >= 400:
            raise ValueError(f"Fireworks received an invalid payload: {response.text}")
        elif response.status_code != 200:
            raise Exception(
                f"Fireworks returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )

        data = response.json()
        output = self._format_output(data)

        return output

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call Fireworks model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.
        """
        headers = {
            "Authorization": f"Bearer {self.fireworks_api_key.get_secret_value()}",
            "Content-Type": "application/json",
        }
        stop_to_use = stop[0] if stop and len(stop) == 1 else stop
        payload: Dict[str, Any] = {
            **self.default_params,
            "prompt": prompt,
            "stop": stop_to_use,
            **kwargs,
        }

        # filter None values to not pass them to the http payload
        payload = {k: v for k, v in payload.items() if v is not None}
        async with ClientSession() as session:
            async with session.post(
                self.base_url, json=payload, headers=headers
            ) as response:
                if response.status >= 500:
                    raise Exception(f"Fireworks Server: Error {response.status}")
                elif response.status >= 400:
                    raise ValueError(
                        f"Fireworks received an invalid payload: {response.text}"
                    )
                elif response.status != 200:
                    raise Exception(
                        f"Fireworks returned an unexpected response with status "
                        f"{response.status}: {response.text}"
                    )

                response_json = await response.json()

                if response_json.get("status") != "finished":
                    err_msg = response_json.get("error", "Undefined Error")
                    raise Exception(err_msg)

                output = self._format_output(response_json)
                return output
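A quick end-to-end sketch of the completion wrapper above (same model name as the integration tests below; assumes `FIREWORKS_API_KEY` is set):

from langchain_fireworks import Fireworks

llm = Fireworks(
    model="accounts/fireworks/models/mixtral-8x7b-instruct",
    temperature=0.2,
    max_tokens=64,
)

# Under the hood, _call above POSTs default_params plus prompt/stop as JSON
# with a Bearer token to base_url, then returns choices[0]["text"].
print(llm.invoke("Say foo:", stop=["\n"]))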
libs/partners/fireworks/langchain_fireworks/version.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""Main entrypoint into package."""
from importlib import metadata

try:
    __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
    # Case where package metadata is not available.
    __version__ = ""
libs/partners/fireworks/poetry.lock (generated, new file, 1541 lines)
File diff suppressed because it is too large
libs/partners/fireworks/pyproject.toml (new file, 97 lines)
@@ -0,0 +1,97 @@
[tool.poetry]
name = "langchain-fireworks"
version = "0.0.1"
description = "An integration package connecting Fireworks and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/fireworks"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.16"
fireworks-ai = ">=0.12.0,<1"
openai = "^1.10.0"
requests = "^2"
aiohttp = "^3.9.1"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.2.2"

[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
types-requests = "^2"

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff.lint]
select = [
  "E",    # pycodestyle
  "F",    # pyflakes
  "I",    # isort
  "T201", # print
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused: prints a warning on unused snapshots rather than failing the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
  "requires: mark tests as requiring a specific library",
  "asyncio: mark tests as requiring asyncio",
  "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
libs/partners/fireworks/scripts/check_imports.py (new file, 17 lines)
@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            # Record the failure so the exit code below reflects it.
            has_failure = True
            print(file)  # noqa: T201
            traceback.print_exc()
            print()  # noqa: T201

    sys.exit(1 if has_failure else 0)
libs/partners/fireworks/scripts/check_pydantic.sh (new executable file, 27 lines)
@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
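For reference, the compliant import pattern the script enforces looks like this (a one-line sketch; the imported names are illustrative):

# Instead of importing pydantic directly, e.g. `from pydantic import BaseModel`:
from langchain_core.pydantic_v1 import BaseModel, Field  # noqa: F401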
libs/partners/fireworks/scripts/lint_imports.sh (new executable file, 17 lines)
@@ -0,0 +1,17 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
libs/partners/fireworks/tests/__init__.py (new empty file, 0 lines)

@@ -0,0 +1,7 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -0,0 +1,20 @@
"""Test Fireworks embeddings."""

from langchain_fireworks.embeddings import FireworksEmbeddings


def test_langchain_fireworks_embedding_documents() -> None:
    """Test Fireworks hosted embeddings."""
    documents = ["foo bar"]
    embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) > 0


def test_langchain_fireworks_embedding_query() -> None:
    """Test Fireworks hosted embeddings."""
    document = "foo bar"
    embedding = FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
    output = embedding.embed_query(document)
    assert len(output) > 0
libs/partners/fireworks/tests/integration_tests/test_llms.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""Test Fireworks API wrapper.

In order to run this test, you need a Fireworks API key.
You can get one by registering for free at https://api.fireworks.ai/.
A test key can be found at https://api.fireworks.ai/settings/api-keys

You'll then need to set the FIREWORKS_API_KEY environment variable to your API key.
"""

import pytest as pytest

from langchain_fireworks import Fireworks


def test_fireworks_call() -> None:
    """Test simple call to fireworks."""
    llm = Fireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    output = llm.invoke("Say foo:")

    assert llm._llm_type == "fireworks"
    assert isinstance(output, str)
    assert len(output) > 0


async def test_fireworks_acall() -> None:
    """Test simple call to fireworks."""
    llm = Fireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    output = await llm.agenerate(["Say foo:"], stop=["bar"])

    assert llm._llm_type == "fireworks"
    output_text = output.generations[0][0].text
    assert isinstance(output_text, str)
    assert output_text.count("bar") <= 1
@@ -0,0 +1,8 @@
"""Test embedding model integration."""

from langchain_fireworks.embeddings import FireworksEmbeddings


def test_initialization() -> None:
    """Test embedding model initialization."""
    FireworksEmbeddings(model="nomic-ai/nomic-embed-text-v1.5")
libs/partners/fireworks/tests/unit_tests/test_imports.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from langchain_fireworks import __all__

EXPECTED_ALL = [
    "__version__",
    "ChatFireworks",
    "Fireworks",
    "FireworksEmbeddings",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
libs/partners/fireworks/tests/unit_tests/test_llms.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Test Fireworks LLM"""

from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_fireworks import Fireworks


def test_fireworks_api_key_is_secret_string() -> None:
    """Test that the API key is stored as a SecretStr."""
    llm = Fireworks(
        fireworks_api_key="secret-api-key",
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    assert isinstance(llm.fireworks_api_key, SecretStr)


def test_fireworks_api_key_masked_when_passed_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Test that the API key is masked when passed from an environment variable."""
    monkeypatch.setenv("FIREWORKS_API_KEY", "secret-api-key")
    llm = Fireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    print(llm.fireworks_api_key, end="")  # noqa: T201
    captured = capsys.readouterr()

    assert captured.out == "**********"


def test_fireworks_api_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    """Test that the API key is masked when passed via the constructor."""
    llm = Fireworks(
        fireworks_api_key="secret-api-key",
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    print(llm.fireworks_api_key, end="")  # noqa: T201
    captured = capsys.readouterr()

    assert captured.out == "**********"


def test_fireworks_uses_actual_secret_value_from_secretstr() -> None:
    """Test that the actual secret value is correctly retrieved."""
    llm = Fireworks(
        fireworks_api_key="secret-api-key",
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        temperature=0.2,
        max_tokens=250,
    )
    assert cast(SecretStr, llm.fireworks_api_key).get_secret_value() == "secret-api-key"