Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-05 08:40:36 +00:00
Compare commits
9 Commits
feat/tool-... ... erick/part
| Author | SHA1 | Date |
|---|---|---|
|  | 1429300ec3 |  |
|  | 8576435bf6 |  |
|  | 84ca98a09c |  |
|  | d6dc88ed94 |  |
|  | 23b3ca821d |  |
|  | a5e9df68e1 |  |
|  | 5aec0d6256 |  |
|  | 667ba171fd |  |
|  | 398dbd2334 |  |
@@ -6,39 +6,127 @@
     BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
 """
-from langchain_core.callbacks.base import (
-    AsyncCallbackHandler,
-    BaseCallbackHandler,
-    BaseCallbackManager,
-    CallbackManagerMixin,
-    Callbacks,
-    ChainManagerMixin,
-    LLMManagerMixin,
-    RetrieverManagerMixin,
-    RunManagerMixin,
-    ToolManagerMixin,
-)
-from langchain_core.callbacks.manager import (
-    AsyncCallbackManager,
-    AsyncCallbackManagerForChainGroup,
-    AsyncCallbackManagerForChainRun,
-    AsyncCallbackManagerForLLMRun,
-    AsyncCallbackManagerForRetrieverRun,
-    AsyncCallbackManagerForToolRun,
-    AsyncParentRunManager,
-    AsyncRunManager,
-    BaseRunManager,
-    CallbackManager,
-    CallbackManagerForChainGroup,
-    CallbackManagerForChainRun,
-    CallbackManagerForLLMRun,
-    CallbackManagerForRetrieverRun,
-    CallbackManagerForToolRun,
-    ParentRunManager,
-    RunManager,
-)
-from langchain_core.callbacks.stdout import StdOutCallbackHandler
-from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from typing import Any
+
+
+def __getattr__(name: str) -> Any:
+    if name == "AsyncCallbackHandler":
+        from langchain_core.callbacks.base import AsyncCallbackHandler
+
+        return AsyncCallbackHandler
+    elif name == "BaseCallbackHandler":
+        from langchain_core.callbacks.base import BaseCallbackHandler
+
+        return BaseCallbackHandler
+    elif name == "BaseCallbackManager":
+        from langchain_core.callbacks.base import BaseCallbackManager
+
+        return BaseCallbackManager
+    elif name == "CallbackManagerMixin":
+        from langchain_core.callbacks.base import CallbackManagerMixin
+
+        return CallbackManagerMixin
+    elif name == "Callbacks":
+        from langchain_core.callbacks.base import Callbacks
+
+        return Callbacks
+    elif name == "ChainManagerMixin":
+        from langchain_core.callbacks.base import ChainManagerMixin
+
+        return ChainManagerMixin
+    elif name == "LLMManagerMixin":
+        from langchain_core.callbacks.base import LLMManagerMixin
+
+        return LLMManagerMixin
+    elif name == "RetrieverManagerMixin":
+        from langchain_core.callbacks.base import RetrieverManagerMixin
+
+        return RetrieverManagerMixin
+    elif name == "ToolManagerMixin":
+        from langchain_core.callbacks.base import ToolManagerMixin
+
+        return ToolManagerMixin
+    elif name == "AsyncCallbackManager":
+        from langchain_core.callbacks.manager import AsyncCallbackManager
+
+        return AsyncCallbackManager
+    elif name == "AsyncCallbackManagerForChainGroup":
+        from langchain_core.callbacks.manager import AsyncCallbackManagerForChainGroup
+
+        return AsyncCallbackManagerForChainGroup
+    elif name == "AsyncCallbackManagerForChainRun":
+        from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
+
+        return AsyncCallbackManagerForChainRun
+    elif name == "AsyncCallbackManagerForLLMRun":
+        from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
+
+        return AsyncCallbackManagerForLLMRun
+    elif name == "AsyncCallbackManagerForRetrieverRun":
+        from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
+
+        return AsyncCallbackManagerForRetrieverRun
+    elif name == "AsyncCallbackManagerForToolRun":
+        from langchain_core.callbacks.manager import AsyncCallbackManagerForToolRun
+
+        return AsyncCallbackManagerForToolRun
+    elif name == "AsyncParentRunManager":
+        from langchain_core.callbacks.manager import AsyncParentRunManager
+
+        return AsyncParentRunManager
+    elif name == "AsyncRunManager":
+        from langchain_core.callbacks.manager import AsyncRunManager
+
+        return AsyncRunManager
+    elif name == "BaseRunManager":
+        from langchain_core.callbacks.manager import BaseRunManager
+
+        return BaseRunManager
+    elif name == "CallbackManager":
+        from langchain_core.callbacks.manager import CallbackManager
+
+        return CallbackManager
+    elif name == "CallbackManagerForChainGroup":
+        from langchain_core.callbacks.manager import CallbackManagerForChainGroup
+
+        return CallbackManagerForChainGroup
+    elif name == "CallbackManagerForChainRun":
+        from langchain_core.callbacks.manager import CallbackManagerForChainRun
+
+        return CallbackManagerForChainRun
+    elif name == "CallbackManagerForLLMRun":
+        from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+
+        return CallbackManagerForLLMRun
+    elif name == "CallbackManagerForRetrieverRun":
+        from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+        return CallbackManagerForRetrieverRun
+    elif name == "CallbackManagerForToolRun":
+        from langchain_core.callbacks.manager import CallbackManagerForToolRun
+
+        return CallbackManagerForToolRun
+    elif name == "ParentRunManager":
+        from langchain_core.callbacks.manager import ParentRunManager
+
+        return ParentRunManager
+    elif name == "RunManager":
+        from langchain_core.callbacks.manager import RunManager
+
+        return RunManager
+    elif name == "StdOutCallbackHandler":
+        from langchain_core.callbacks.stdout import StdOutCallbackHandler
+
+        return StdOutCallbackHandler
+    elif name == "StreamingStdOutCallbackHandler":
+        from langchain_core.callbacks.streaming_stdout import (
+            StreamingStdOutCallbackHandler,
+        )
+
+        return StreamingStdOutCallbackHandler
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
 __all__ = [
     "RetrieverManagerMixin",
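The hunk above replaces eager top-level imports with a module-level `__getattr__` (PEP 562, Python 3.7+), so each exported symbol is imported only on first attribute access. A minimal sketch of an optional refinement that is not part of this diff: caching the resolved object in the module globals so the import branch runs at most once per name (`StdOutCallbackHandler` stands in for any lazily exported symbol).

```python
# Hypothetical variant of the pattern above: cache the resolved attribute in
# globals() so later lookups bypass __getattr__ entirely.
from typing import Any


def __getattr__(name: str) -> Any:
    if name == "StdOutCallbackHandler":
        from langchain_core.callbacks.stdout import StdOutCallbackHandler

        globals()[name] = StdOutCallbackHandler  # later accesses hit this binding
        return StdOutCallbackHandler
    raise AttributeError(f"module {__name__} has no attribute {name}")
```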
@@ -26,7 +26,6 @@ from typing import (
 )
 from uuid import UUID
 
-from langsmith.run_helpers import get_run_tree_context
 from tenacity import RetryCallState
 
 from langchain_core.callbacks.base import (
@@ -39,7 +38,6 @@ from langchain_core.callbacks.base import (
     RunManagerMixin,
     ToolManagerMixin,
 )
-from langchain_core.callbacks.stdout import StdOutCallbackHandler
 from langchain_core.messages import BaseMessage, get_buffer_string
 from langchain_core.utils.env import env_var_is_set
@@ -1910,6 +1908,8 @@ def _configure(
     Returns:
         T: The configured callback manager.
     """
+    from langsmith.run_helpers import get_run_tree_context
+
     from langchain_core.tracers.context import (
         _configure_hooks,
         _get_tracer_project,
@@ -1970,6 +1970,7 @@ def _configure(
     tracer_project = _get_tracer_project()
     debug = _get_debug()
     if verbose or debug or tracing_v2_enabled_:
+        from langchain_core.callbacks.stdout import StdOutCallbackHandler
         from langchain_core.tracers.langchain import LangChainTracer
         from langchain_core.tracers.stdout import ConsoleCallbackHandler
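These `_configure` hunks push the langsmith, tracer, and stdout-handler imports down into the code paths that need them, which pays off at package import time. A rough way to check the effect, using only the standard library (the timing and module name are illustrative):

```python
# Measure the cold-import cost of the package in a fresh interpreter; spawning
# a subprocess avoids the warm sys.modules cache of the current process.
import subprocess
import sys
import time

start = time.perf_counter()
subprocess.run([sys.executable, "-c", "import langchain_core.callbacks"], check=True)
print(f"cold import took {time.perf_counter() - start:.3f}s")
```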
@@ -50,7 +50,6 @@ from langchain_core.outputs import (
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.runnables.config import ensure_config, run_in_executor
-from langchain_core.tracers.log_stream import LogStreamCallbackHandler
 
 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
@@ -597,18 +596,23 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _stream not implemented
+
+        from langchain_core.tracers.log_stream import LogStreamCallbackHandler
+
         if type(self)._stream != BaseChatModel._stream and kwargs.pop(
             "stream",
-            next(
-                (
-                    True
-                    for h in run_manager.handlers
-                    if isinstance(h, LogStreamCallbackHandler)
-                ),
-                False,
-            )
-            if run_manager
-            else False,
+            (
+                next(
+                    (
+                        True
+                        for h in run_manager.handlers
+                        if isinstance(h, LogStreamCallbackHandler)
+                    ),
+                    False,
+                )
+                if run_manager
+                else False
+            ),
         ):
             chunks: List[ChatGenerationChunk] = []
             for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -675,21 +679,26 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
         # If stream is not explicitly set, check if implicitly requested by
         # astream_events() or astream_log(). Bail out if _astream not implemented
+
+        from langchain_core.tracers.log_stream import LogStreamCallbackHandler
+
         if (
             type(self)._astream != BaseChatModel._astream
             or type(self)._stream != BaseChatModel._stream
         ) and kwargs.pop(
             "stream",
-            next(
-                (
-                    True
-                    for h in run_manager.handlers
-                    if isinstance(h, LogStreamCallbackHandler)
-                ),
-                False,
-            )
-            if run_manager
-            else False,
+            (
+                next(
+                    (
+                        True
+                        for h in run_manager.handlers
+                        if isinstance(h, LogStreamCallbackHandler)
+                    ),
+                    False,
+                )
+                if run_manager
+                else False
+            ),
         ):
             chunks: List[ChatGenerationChunk] = []
             async for chunk in self._astream(messages, stop=stop, **kwargs):
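Both hunks rely on the same override-detection idiom: `type(self)._stream != BaseChatModel._stream` is true only when a subclass defines its own `_stream`. A self-contained sketch of the idiom (class names are illustrative):

```python
# Comparing the attribute found on the concrete class against the base class's
# own function reveals whether a subclass overrides the method.
class Base:
    def _stream(self) -> None:
        raise NotImplementedError


class Streaming(Base):
    def _stream(self) -> None:
        pass


assert type(Streaming())._stream is not Base._stream  # override detected
assert type(Base())._stream is Base._stream           # no override
```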
@@ -4,27 +4,96 @@
 These functions do not depend on any other LangChain module.
 """
 
-from langchain_core.utils import image
-from langchain_core.utils.env import get_from_dict_or_env, get_from_env
-from langchain_core.utils.formatting import StrictFormatter, formatter
-from langchain_core.utils.input import (
-    get_bolded_text,
-    get_color_mapping,
-    get_colored_text,
-    print_text,
-)
-from langchain_core.utils.loading import try_load_from_hub
-from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value
-from langchain_core.utils.utils import (
-    build_extra_kwargs,
-    check_package_version,
-    convert_to_secret_str,
-    get_pydantic_field_names,
-    guard_import,
-    mock_now,
-    raise_for_status_with_text,
-    xor_args,
-)
+from typing import Any
+
+
+def __getattr__(name: str) -> Any:
+    if name == "image":
+        import langchain_core.utils.image
+
+        return langchain_core.utils.image
+    elif name == "get_from_dict_or_env":
+        from langchain_core.utils.env import get_from_dict_or_env
+
+        return get_from_dict_or_env
+    elif name == "get_from_env":
+        from langchain_core.utils.env import get_from_env
+
+        return get_from_env
+    elif name == "StrictFormatter":
+        from langchain_core.utils.formatting import StrictFormatter
+
+        return StrictFormatter
+    elif name == "formatter":
+        from langchain_core.utils.formatting import formatter
+
+        return formatter
+    elif name == "get_bolded_text":
+        from langchain_core.utils.input import get_bolded_text
+
+        return get_bolded_text
+    elif name == "get_color_mapping":
+        from langchain_core.utils.input import get_color_mapping
+
+        return get_color_mapping
+    elif name == "get_colored_text":
+        from langchain_core.utils.input import get_colored_text
+
+        return get_colored_text
+    elif name == "print_text":
+        from langchain_core.utils.input import print_text
+
+        return print_text
+    elif name == "try_load_from_hub":
+        from langchain_core.utils.loading import try_load_from_hub
+
+        return try_load_from_hub
+    elif name == "comma_list":
+        from langchain_core.utils.strings import comma_list
+
+        return comma_list
+    elif name == "stringify_dict":
+        from langchain_core.utils.strings import stringify_dict
+
+        return stringify_dict
+    elif name == "stringify_value":
+        from langchain_core.utils.strings import stringify_value
+
+        return stringify_value
+    elif name == "build_extra_kwargs":
+        from langchain_core.utils.utils import build_extra_kwargs
+
+        return build_extra_kwargs
+    elif name == "check_package_version":
+        from langchain_core.utils.utils import check_package_version
+
+        return check_package_version
+    elif name == "convert_to_secret_str":
+        from langchain_core.utils.utils import convert_to_secret_str
+
+        return convert_to_secret_str
+    elif name == "get_pydantic_field_names":
+        from langchain_core.utils.utils import get_pydantic_field_names
+
+        return get_pydantic_field_names
+    elif name == "guard_import":
+        from langchain_core.utils.utils import guard_import
+
+        return guard_import
+    elif name == "mock_now":
+        from langchain_core.utils.utils import mock_now
+
+        return mock_now
+    elif name == "raise_for_status_with_text":
+        from langchain_core.utils.utils import raise_for_status_with_text
+
+        return raise_for_status_with_text
+    elif name == "xor_args":
+        from langchain_core.utils.utils import xor_args
+
+        return xor_args
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
 __all__ = [
     "StrictFormatter",
@@ -1,17 +1,20 @@
 """Generic utility functions."""
 import contextlib
 import datetime
 import functools
 import importlib
 import warnings
 from importlib.metadata import version
-from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Tuple, Union
 
 from packaging.version import parse
-from requests import HTTPError, Response
 
 from langchain_core.pydantic_v1 import SecretStr
 
+if TYPE_CHECKING:
+    from requests import Response
+
 
 def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
     """Validate specified keyword args are mutually exclusive."""
@@ -39,8 +42,11 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
     return decorator
 
 
-def raise_for_status_with_text(response: Response) -> None:
+def raise_for_status_with_text(response: "Response") -> None:
     """Raise an error with the response text."""
+    from requests import HTTPError
+
     try:
         response.raise_for_status()
     except HTTPError as e:
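The `TYPE_CHECKING` guard above means `requests` is imported only while a type checker runs; at runtime the name exists solely inside the quoted annotation, and `HTTPError` is imported on first call instead. A compact sketch of the idiom, with the `except` body filled in illustratively since the diff truncates it:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from requests import Response  # never executed at runtime


def raise_for_status_with_text(response: "Response") -> None:
    """Raise an error with the response text."""
    from requests import HTTPError  # deferred: imported only when called

    try:
        response.raise_for_status()
    except HTTPError as e:
        # illustrative handling; the real function's body is cut off above
        raise ValueError(response.text) from e
```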
@@ -1,4 +1,16 @@
-from langchain_mistralai.chat_models import ChatMistralAI
-from langchain_mistralai.embeddings import MistralAIEmbeddings
+from typing import Any
+
+
+def __getattr__(name: str) -> Any:
+    if name == "ChatMistralAI":
+        from langchain_mistralai.chat_models import ChatMistralAI
+
+        return ChatMistralAI
+    elif name == "MistralAIEmbeddings":
+        from langchain_mistralai.embeddings import MistralAIEmbeddings
+
+        return MistralAIEmbeddings
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
 
 __all__ = ["ChatMistralAI", "MistralAIEmbeddings"]
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from operator import itemgetter
 from typing import (
+    TYPE_CHECKING,
     Any,
     AsyncContextManager,
     AsyncIterator,
@@ -18,14 +19,7 @@ from typing import (
     cast,
 )
 
-import httpx
-from httpx_sse import EventSource, aconnect_sse, connect_sse
 from langchain_core._api import beta
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
     agenerate_from_stream,
@@ -57,6 +51,15 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from langchain_core.utils.function_calling import convert_to_openai_tool
 
+if TYPE_CHECKING:
+    import httpx
+    from httpx_sse import EventSource
+    from langchain_core.callbacks import (
+        AsyncCallbackManagerForLLMRun,
+        CallbackManagerForLLMRun,
+    )
+    from langchain_core.language_models import LanguageModelInput
+
 logger = logging.getLogger(__name__)
@@ -67,8 +70,9 @@ def _create_retry_decorator(
     ] = None,
 ) -> Callable[[Any], Any]:
     """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
+    from httpx import RequestError, StreamError
 
-    errors = [httpx.RequestError, httpx.StreamError]
+    errors = [RequestError, StreamError]
     return create_base_retry_decorator(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
     )
@@ -112,6 +116,8 @@ async def acompletion_with_retry(
         kwargs["stream"] = False
     stream = kwargs["stream"]
     if stream:
+        from httpx_sse import aconnect_sse
+
         event_source = aconnect_sse(
             llm.async_client, "POST", "/chat/completions", json=kwargs
         )
@@ -184,8 +190,8 @@ def _convert_message_to_mistral_chat_message(
 class ChatMistralAI(BaseChatModel):
     """A chat model that uses the MistralAI API."""
 
-    client: httpx.Client = Field(default=None)  #: :meta private:
-    async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
+    _client: Optional[httpx.Client] = None
+    _async_client: Optional[httpx.AsyncClient] = None
     mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
     endpoint: str = "https://api.mistral.ai/v1"
     max_retries: int = 5
@@ -202,6 +208,83 @@ class ChatMistralAI(BaseChatModel):
     safe_mode: bool = False
     streaming: bool = False
 
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate api key, python package exists, temperature, and top_p."""
+        values["mistral_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(
+                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
+            )
+        )
+        if "client" in values:
+            values["_client"] = values["client"]
+            del values["client"]
+        if "async_client" in values:
+            values["_async_client"] = values["async_client"]
+            del values["async_client"]
+
+        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
+            raise ValueError("temperature must be in the range [0.0, 1.0]")
+
+        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
+            raise ValueError("top_p must be in the range [0.0, 1.0]")
+
+        return values
+
+    @property
+    def client(self) -> httpx.Client:
+        """Get the client."""
+        if self._client:
+            return self._client
+
+        from httpx import Client
+
+        # todo: handle retries
+        return Client(
+            base_url=self.endpoint,
+            headers={
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                **(
+                    {
+                        "Authorization": (
+                            f"Bearer {self.mistral_api_key.get_secret_value()}"
+                        )
+                    }
+                    if self.mistral_api_key
+                    else {}
+                ),
+            },
+            timeout=self.timeout,
+        )
+
+    @property
+    def async_client(self) -> httpx.AsyncClient:
+        """Get the async client."""
+        if self._async_client:
+            return self._async_client
+        # todo: handle retries and max concurrency
+        from httpx import AsyncClient
+
+        return AsyncClient(
+            base_url=self.endpoint,
+            headers={
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+                **(
+                    {
+                        "Authorization": (
+                            f"Bearer {self.mistral_api_key.get_secret_value()}"
+                        )
+                    }
+                    if self.mistral_api_key
+                    else {}
+                ),
+            },
+            timeout=self.timeout,
+        )
+
     class Config:
         """Configuration for this pydantic object."""
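With construction moved out of the validator, httpx is imported only when a client is first needed. One consequence worth noting: when `_client` is unset, the property builds a fresh `Client` on every access rather than caching it. A hedged usage sketch of injecting a preconfigured client, which the new validator stashes on `_client` (assuming the model's pydantic config accepts the extra keyword):

```python
# Sketch under the diff's validator semantics: a caller-supplied client is
# moved onto _client, so the property returns it instead of building one.
import httpx

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(client=httpx.Client(timeout=30.0))  # reused on every request
```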
@@ -241,6 +324,8 @@ class ChatMistralAI(BaseChatModel):
         if stream:
 
             def iter_sse() -> Iterator[Dict]:
+                from httpx_sse import connect_sse
+
                 with connect_sse(
                     self.client, "POST", "/chat/completions", json=kwargs
                 ) as event_source:
@@ -272,45 +357,6 @@ class ChatMistralAI(BaseChatModel):
         combined = {"token_usage": overall_token_usage, "model_name": self.model}
         return combined
 
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate api key, python package exists, temperature, and top_p."""
-        values["mistral_api_key"] = convert_to_secret_str(
-            get_from_dict_or_env(
-                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
-            )
-        )
-        api_key_str = values["mistral_api_key"].get_secret_value()
-        # todo: handle retries
-        values["client"] = httpx.Client(
-            base_url=values["endpoint"],
-            headers={
-                "Content-Type": "application/json",
-                "Accept": "application/json",
-                "Authorization": f"Bearer {api_key_str}",
-            },
-            timeout=values["timeout"],
-        )
-        # todo: handle retries and max_concurrency
-        values["async_client"] = httpx.AsyncClient(
-            base_url=values["endpoint"],
-            headers={
-                "Content-Type": "application/json",
-                "Accept": "application/json",
-                "Authorization": f"Bearer {api_key_str}",
-            },
-            timeout=values["timeout"],
-        )
-
-        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
-            raise ValueError("temperature must be in the range [0.0, 1.0]")
-
-        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
-            raise ValueError("top_p must be in the range [0.0, 1.0]")
-
-        return values
-
     def _generate(
         self,
         messages: List[BaseMessage],
misc/memory-benchmarking/chatmistralai/README.md (new file, 0 lines)
misc/memory-benchmarking/chatmistralai/chatmistralai/chatmistralai.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from langchain_mistralai import ChatMistralAI
+from langchain_core.prompts import ChatPromptTemplate
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", "You are a helpful assistant."),
+        ("human", "What is the capital of {country}?"),
+    ]
+)
+
+chat = ChatMistralAI()
+
+chain = prompt | chat
+
+chain.invoke({"country": "Denmark"})
misc/memory-benchmarking/chatmistralai/poetry.lock (generated, new file, 1156 lines)
File diff suppressed because it is too large
misc/memory-benchmarking/chatmistralai/pyproject.toml (new file, 18 lines)
@@ -0,0 +1,18 @@
+[tool.poetry]
+name = "chatmistralai"
+version = "0.1.0"
+description = ""
+authors = ["Erick Friis <erick@langchain.dev>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.11"
+langchain-mistralai = {path = "../../../libs/partners/mistralai", develop = true}
+langchain-core = {path = "../../../libs/core", develop = true}
+memray = "^1.11.0"
+langsmith = {path = "/Users/erickfriis/langchain/smith-sdk/python", develop = true}
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
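The benchmark project pins memray, so the script above is presumably run under it. A hedged sketch of capturing an allocation profile in-process with memray's `Tracker` API (file name is illustrative; the CLI equivalent is `memray run chatmistralai.py` followed by `memray flamegraph` on the capture file):

```python
# Record all allocations made while importing and running the benchmark chain
# into a capture file that memray's reporters (flamegraph, stats) can read.
from memray import Tracker

with Tracker("chatmistralai.bin"):
    from langchain_mistralai import ChatMistralAI
    from langchain_core.prompts import ChatPromptTemplate

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful assistant."),
            ("human", "What is the capital of {country}?"),
        ]
    )
    chain = prompt | ChatMistralAI()
    chain.invoke({"country": "Denmark"})
```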
misc/memory-benchmarking/chatopenai/README.md (new file, 0 lines)
misc/memory-benchmarking/chatopenai/chatopenai/chatopenai.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+
+prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", "You are a helpful assistant."),
+        ("human", "What is the capital of {country}?"),
+    ]
+)
+
+chat = ChatOpenAI()
+
+chain = prompt | chat
+
+chain.invoke({"country": "Denmark"})
misc/memory-benchmarking/chatopenai/poetry.lock (generated, new file, 1118 lines)
File diff suppressed because it is too large
misc/memory-benchmarking/chatopenai/pyproject.toml (new file, 17 lines)
@@ -0,0 +1,17 @@
+[tool.poetry]
+name = "chatopenai"
+version = "0.1.0"
+description = ""
+authors = ["Erick Friis <erick@langchain.dev>"]
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = "^3.11"
+langchain-core = {path = "../../../libs/core", develop = true}
+langchain-openai = {path = "../../../libs/partners/openai", develop = true}
+memray = "^1.11.0"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"