Compare commits

...

9 Commits

Author SHA1 Message Date
Erick Friis
1429300ec3 x 2024-04-08 11:52:09 -07:00
Erick Friis
8576435bf6 x 2024-04-08 11:51:32 -07:00
Erick Friis
84ca98a09c x 2024-04-08 11:49:56 -07:00
Erick Friis
d6dc88ed94 x 2024-04-08 11:49:00 -07:00
Erick Friis
23b3ca821d Merge branch 'master' into erick/partner-cloudflare 2024-04-04 14:57:19 -07:00
Erick Friis
a5e9df68e1 x 2024-04-02 15:42:28 -07:00
Erick Friis
5aec0d6256 merge 2024-04-02 15:21:55 -07:00
Erick Friis
667ba171fd x 2024-04-02 15:21:26 -07:00
Erick Friis
398dbd2334 initpkg 2024-04-01 17:45:12 -07:00
21 changed files with 3377 additions and 131 deletions

View File

@@ -6,39 +6,127 @@
BaseCallbackHandler --> <name>CallbackHandler # Example: AimCallbackHandler
"""
from langchain_core.callbacks.base import (
AsyncCallbackHandler,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManagerMixin,
Callbacks,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
from langchain_core.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainGroup,
AsyncCallbackManagerForChainRun,
AsyncCallbackManagerForLLMRun,
AsyncCallbackManagerForRetrieverRun,
AsyncCallbackManagerForToolRun,
AsyncParentRunManager,
AsyncRunManager,
BaseRunManager,
CallbackManager,
CallbackManagerForChainGroup,
CallbackManagerForChainRun,
CallbackManagerForLLMRun,
CallbackManagerForRetrieverRun,
CallbackManagerForToolRun,
ParentRunManager,
RunManager,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from typing import Any
def __getattr__(name: str) -> Any:
if name == "AsyncCallbackHandler":
from langchain_core.callbacks.base import AsyncCallbackHandler
return AsyncCallbackHandler
elif name == "BaseCallbackHandler":
from langchain_core.callbacks.base import BaseCallbackHandler
return BaseCallbackHandler
elif name == "BaseCallbackManager":
from langchain_core.callbacks.base import BaseCallbackManager
return BaseCallbackManager
elif name == "CallbackManagerMixin":
from langchain_core.callbacks.base import CallbackManagerMixin
return CallbackManagerMixin
elif name == "Callbacks":
from langchain_core.callbacks.base import Callbacks
return Callbacks
elif name == "ChainManagerMixin":
from langchain_core.callbacks.base import ChainManagerMixin
return ChainManagerMixin
elif name == "LLMManagerMixin":
from langchain_core.callbacks.base import LLMManagerMixin
return LLMManagerMixin
elif name == "RetrieverManagerMixin":
from langchain_core.callbacks.base import RetrieverManagerMixin
return RetrieverManagerMixin
elif name == "ToolManagerMixin":
from langchain_core.callbacks.base import ToolManagerMixin
return ToolManagerMixin
elif name == "AsyncCallbackManager":
from langchain_core.callbacks.manager import AsyncCallbackManager
return AsyncCallbackManager
elif name == "AsyncCallbackManagerForChainGroup":
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainGroup
return AsyncCallbackManagerForChainGroup
elif name == "AsyncCallbackManagerForChainRun":
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
return AsyncCallbackManagerForChainRun
elif name == "AsyncCallbackManagerForLLMRun":
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
return AsyncCallbackManagerForLLMRun
elif name == "AsyncCallbackManagerForRetrieverRun":
from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
return AsyncCallbackManagerForRetrieverRun
elif name == "AsyncCallbackManagerForToolRun":
from langchain_core.callbacks.manager import AsyncCallbackManagerForToolRun
return AsyncCallbackManagerForToolRun
elif name == "AsyncParentRunManager":
from langchain_core.callbacks.manager import AsyncParentRunManager
return AsyncParentRunManager
elif name == "AsyncRunManager":
from langchain_core.callbacks.manager import AsyncRunManager
return AsyncRunManager
elif name == "BaseRunManager":
from langchain_core.callbacks.manager import BaseRunManager
return BaseRunManager
elif name == "CallbackManager":
from langchain_core.callbacks.manager import CallbackManager
return CallbackManager
elif name == "CallbackManagerForChainGroup":
from langchain_core.callbacks.manager import CallbackManagerForChainGroup
return CallbackManagerForChainGroup
elif name == "CallbackManagerForChainRun":
from langchain_core.callbacks.manager import CallbackManagerForChainRun
return CallbackManagerForChainRun
elif name == "CallbackManagerForLLMRun":
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
return CallbackManagerForLLMRun
elif name == "CallbackManagerForRetrieverRun":
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
return CallbackManagerForRetrieverRun
elif name == "CallbackManagerForToolRun":
from langchain_core.callbacks.manager import CallbackManagerForToolRun
return CallbackManagerForToolRun
elif name == "ParentRunManager":
from langchain_core.callbacks.manager import ParentRunManager
return ParentRunManager
elif name == "RunManager":
from langchain_core.callbacks.manager import RunManager
return RunManager
elif name == "StdOutCallbackHandler":
from langchain_core.callbacks.stdout import StdOutCallbackHandler
return StdOutCallbackHandler
elif name == "StreamingStdOutCallbackHandler":
from langchain_core.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler,
)
return StreamingStdOutCallbackHandler
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
"RetrieverManagerMixin",

View File

@@ -26,7 +26,6 @@ from typing import (
)
from uuid import UUID
from langsmith.run_helpers import get_run_tree_context
from tenacity import RetryCallState
from langchain_core.callbacks.base import (
@@ -39,7 +38,6 @@ from langchain_core.callbacks.base import (
RunManagerMixin,
ToolManagerMixin,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.utils.env import env_var_is_set
@@ -1910,6 +1908,8 @@ def _configure(
Returns:
T: The configured callback manager.
"""
from langsmith.run_helpers import get_run_tree_context
from langchain_core.tracers.context import (
_configure_hooks,
_get_tracer_project,
@@ -1970,6 +1970,7 @@ def _configure(
tracer_project = _get_tracer_project()
debug = _get_debug()
if verbose or debug or tracing_v2_enabled_:
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.stdout import ConsoleCallbackHandler

View File

@@ -50,7 +50,6 @@ from langchain_core.outputs import (
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
@@ -597,18 +596,23 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
"stream",
next(
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
),
False,
)
if run_manager
else False,
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
),
False,
)
if run_manager
else False
),
):
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
@@ -675,21 +679,26 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
from langchain_core.tracers.log_stream import LogStreamCallbackHandler
if (
type(self)._astream != BaseChatModel._astream
or type(self)._stream != BaseChatModel._stream
) and kwargs.pop(
"stream",
next(
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
),
False,
)
if run_manager
else False,
(
next(
(
True
for h in run_manager.handlers
if isinstance(h, LogStreamCallbackHandler)
),
False,
)
if run_manager
else False
),
):
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):

View File

@@ -4,27 +4,96 @@
These functions do not depend on any other LangChain module.
"""
from langchain_core.utils import image
from langchain_core.utils.env import get_from_dict_or_env, get_from_env
from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
from langchain_core.utils.loading import try_load_from_hub
from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value
from langchain_core.utils.utils import (
build_extra_kwargs,
check_package_version,
convert_to_secret_str,
get_pydantic_field_names,
guard_import,
mock_now,
raise_for_status_with_text,
xor_args,
)
from typing import Any
def __getattr__(name: str) -> Any:
if name == "image":
import langchain_core.utils.image
return langchain_core.utils.image
elif name == "get_from_dict_or_env":
from langchain_core.utils.env import get_from_dict_or_env
return get_from_dict_or_env
elif name == "get_from_env":
from langchain_core.utils.env import get_from_env
return get_from_env
elif name == "StrictFormatter":
from langchain_core.utils.formatting import StrictFormatter
return StrictFormatter
elif name == "formatter":
from langchain_core.utils.formatting import formatter
return formatter
elif name == "get_bolded_text":
from langchain_core.utils.input import get_bolded_text
return get_bolded_text
elif name == "get_color_mapping":
from langchain_core.utils.input import get_color_mapping
return get_color_mapping
elif name == "get_colored_text":
from langchain_core.utils.input import get_colored_text
return get_colored_text
elif name == "print_text":
from langchain_core.utils.input import print_text
return print_text
elif name == "try_load_from_hub":
from langchain_core.utils.loading import try_load_from_hub
return try_load_from_hub
elif name == "comma_list":
from langchain_core.utils.strings import comma_list
return comma_list
elif name == "stringify_dict":
from langchain_core.utils.strings import stringify_dict
return stringify_dict
elif name == "stringify_value":
from langchain_core.utils.strings import stringify_value
return stringify_value
elif name == "build_extra_kwargs":
from langchain_core.utils.utils import build_extra_kwargs
return build_extra_kwargs
elif name == "check_package_version":
from langchain_core.utils.utils import check_package_version
return check_package_version
elif name == "convert_to_secret_str":
from langchain_core.utils.utils import convert_to_secret_str
return convert_to_secret_str
elif name == "get_pydantic_field_names":
from langchain_core.utils.utils import get_pydantic_field_names
return get_pydantic_field_names
elif name == "guard_import":
from langchain_core.utils.utils import guard_import
return guard_import
elif name == "mock_now":
from langchain_core.utils.utils import mock_now
return mock_now
elif name == "raise_for_status_with_text":
from langchain_core.utils.utils import raise_for_status_with_text
return raise_for_status_with_text
elif name == "xor_args":
from langchain_core.utils.utils import xor_args
return xor_args
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = [
"StrictFormatter",

View File

@@ -1,17 +1,20 @@
"""Generic utility functions."""
import contextlib
import datetime
import functools
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Tuple, Union
from packaging.version import parse
from requests import HTTPError, Response
from langchain_core.pydantic_v1 import SecretStr
if TYPE_CHECKING:
from requests import Response
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
@@ -39,8 +42,11 @@ def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
return decorator
def raise_for_status_with_text(response: Response) -> None:
def raise_for_status_with_text(response: "Response") -> None:
"""Raise an error with the response text."""
from requests import HTTPError
try:
response.raise_for_status()
except HTTPError as e:

View File

@@ -1,4 +1,16 @@
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_mistralai.embeddings import MistralAIEmbeddings
from typing import Any
def __getattr__(name: str) -> Any:
if name == "ChatMistralAI":
from langchain_mistralai.chat_models import ChatMistralAI
return ChatMistralAI
elif name == "MistralAIEmbeddings":
from langchain_mistralai.embeddings import MistralAIEmbeddings
return MistralAIEmbeddings
raise AttributeError(f"module {__name__} has no attribute {name}")
__all__ = ["ChatMistralAI", "MistralAIEmbeddings"]

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
@@ -18,14 +19,7 @@ from typing import (
cast,
)
import httpx
from httpx_sse import EventSource, aconnect_sse, connect_sse
from langchain_core._api import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -57,6 +51,15 @@ from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
import httpx
from httpx_sse import EventSource
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
logger = logging.getLogger(__name__)
@@ -67,8 +70,9 @@ def _create_retry_decorator(
] = None,
) -> Callable[[Any], Any]:
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
from httpx import RequestError, StreamError
errors = [httpx.RequestError, httpx.StreamError]
errors = [RequestError, StreamError]
return create_base_retry_decorator(
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
)
@@ -112,6 +116,8 @@ async def acompletion_with_retry(
kwargs["stream"] = False
stream = kwargs["stream"]
if stream:
from httpx_sse import aconnect_sse
event_source = aconnect_sse(
llm.async_client, "POST", "/chat/completions", json=kwargs
)
@@ -184,8 +190,8 @@ def _convert_message_to_mistral_chat_message(
class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API."""
client: httpx.Client = Field(default=None) #: :meta private:
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
_client: Optional[httpx.Client] = None
_async_client: Optional[httpx.AsyncClient] = None
mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
endpoint: str = "https://api.mistral.ai/v1"
max_retries: int = 5
@@ -202,6 +208,83 @@ class ChatMistralAI(BaseChatModel):
safe_mode: bool = False
streaming: bool = False
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
if "client" in values:
values["_client"] = values["client"]
del values["client"]
if "async_client" in values:
values["_async_client"] = values["async_client"]
del values["async_client"]
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
return values
@property
def client(self) -> httpx.Client:
"""Get the client."""
if self._client:
return self._client
from httpx import Client
# todo: handle retries
return Client(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
**(
{
"Authorization": (
f"Bearer {self.mistral_api_key.get_secret_value()}"
)
}
if self.mistral_api_key
else {}
),
},
timeout=self.timeout,
)
@property
def async_client(self) -> httpx.AsyncClient:
"""Get the async client."""
if self._async_client:
return self._async_client
# todo: handle retries and max concurrency
from httpx import AsyncClient
return AsyncClient(
base_url=self.endpoint,
headers={
"Content-Type": "application/json",
"Accept": "application/json",
**(
{
"Authorization": (
f"Bearer {self.mistral_api_key.get_secret_value()}"
)
}
if self.mistral_api_key
else {}
),
},
timeout=self.timeout,
)
class Config:
"""Configuration for this pydantic object."""
@@ -241,6 +324,8 @@ class ChatMistralAI(BaseChatModel):
if stream:
def iter_sse() -> Iterator[Dict]:
from httpx_sse import connect_sse
with connect_sse(
self.client, "POST", "/chat/completions", json=kwargs
) as event_source:
@@ -272,45 +357,6 @@ class ChatMistralAI(BaseChatModel):
combined = {"token_usage": overall_token_usage, "model_name": self.model}
return combined
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, and top_p."""
values["mistral_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
)
api_key_str = values["mistral_api_key"].get_secret_value()
# todo: handle retries
values["client"] = httpx.Client(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
# todo: handle retries and max_concurrency
values["async_client"] = httpx.AsyncClient(
base_url=values["endpoint"],
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key_str}",
},
timeout=values["timeout"],
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")
if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
raise ValueError("top_p must be in the range [0.0, 1.0]")
return values
def _generate(
self,
messages: List[BaseMessage],

View File

@@ -0,0 +1,15 @@
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful assistant."),
("human", "What is the capital of {country}?"),
]
)
chat = ChatMistralAI()
chain = prompt | chat
chain.invoke({"country": "Denmark"})

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
[tool.poetry]
name = "chatmistralai"
version = "0.1.0"
description = ""
authors = ["Erick Friis <erick@langchain.dev>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
langchain-mistralai = {path = "../../../libs/partners/mistralai", develop = true}
langchain-core = {path = "../../../libs/core", develop = true}
memray = "^1.11.0"
langsmith = {path = "/Users/erickfriis/langchain/smith-sdk/python", develop = true}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,15 @@
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
prompt = ChatPromptTemplate.from_messages(
[
("system", "You are a helpful assistant."),
("human", "What is the capital of {country}?"),
]
)
chat = ChatOpenAI()
chain = prompt | chat
chain.invoke({"country": "Denmark"})

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
[tool.poetry]
name = "chatopenai"
version = "0.1.0"
description = ""
authors = ["Erick Friis <erick@langchain.dev>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
langchain-core = {path = "../../../libs/core", develop = true}
langchain-openai = {path = "../../../libs/partners/openai", develop = true}
memray = "^1.11.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"