Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-05 13:06:03 +00:00)

Commit: Merge branch 'langchain-ai:master' into master
@@ -29,6 +29,15 @@ class SQLDatabaseChain(Chain):

            from langchain import OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include the permissions this chain needs.
    Failure to do so may result in data corruption or loss, since this chain may
    attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this chain.
    This issue shows an example negative outcome if these steps are not taken:
    https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
@@ -49,7 +58,7 @@ class SQLDatabaseChain(Chain):
    return_direct: bool = False
    """Whether or not to return the result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether or not the query checker tool should be used to attempt
    to fix the initial SQL from the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """The prompt template that should be used by the query checker"""
@@ -197,6 +206,17 @@ class SQLDatabaseChain(Chain):
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)
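For orientation, a minimal sketch of the security note above (not part of this diff): build the SQLDatabase from a connection string whose user has read-only grants, then construct the chain with from_llm. The URI and the reporting user are hypothetical, and SQLDatabaseChain is assumed importable as in the class docstring.

from langchain import OpenAI, SQLDatabase

# Hypothetical read-only user, e.g. in Postgres:
#   CREATE ROLE reporting LOGIN PASSWORD '...';
#   GRANT SELECT ON ALL TABLES IN SCHEMA public TO reporting;
db = SQLDatabase.from_uri("postgresql://reporting:password@localhost:5432/mydb")
db_chain = SQLDatabaseChain.from_llm(
    OpenAI(temperature=0),
    db,
    use_query_checker=True,  # have the LLM double-check its SQL before running it
)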
libs/experimental/tests/unit_tests/conftest.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence

import pytest
from pytest import Config, Function, Parser


def pytest_addoption(parser: Parser) -> None:
    """Add custom command line options to pytest."""
    parser.addoption(
        "--only-extended",
        action="store_true",
        help="Only run extended tests. Does not allow skipping any extended tests.",
    )
    parser.addoption(
        "--only-core",
        action="store_true",
        help="Only run core tests. Never runs any extended tests.",
    )


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        raise ValueError("Cannot specify both `--only-extended` and `--only-core`.")

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is not None:
            if only_core:
                item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
                continue

            # Iterate through the list of required packages
            required_pkgs = requires_marker.args
            for pkg in required_pkgs:
                # If we haven't yet checked whether the pkg is installed
                # let's check it and store the result.
                if pkg not in required_pkgs_info:
                    required_pkgs_info[pkg] = util.find_spec(pkg) is not None

                if not required_pkgs_info[pkg]:
                    if only_extended:
                        pytest.fail(
                            f"Package `{pkg}` is not installed but is required for "
                            f"extended tests. Please install the given package and "
                            f"try again.",
                        )

                    else:
                        # If the package is not installed, we immediately break
                        # and mark the test as skipped.
                        item.add_marker(
                            pytest.mark.skip(reason=f"Requires pkg: `{pkg}`")
                        )
                        break
        else:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test.")
                )
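For orientation, a hedged sketch of how a test module would use the marker and options defined above (the test name and the faiss package are hypothetical examples):

import pytest


@pytest.mark.requires("faiss")
def test_faiss_roundtrip() -> None:
    # Only runs when `faiss` is importable; otherwise the conftest skips it.
    import faiss

    assert faiss is not None


# pytest --only-core      -> runs unmarked tests, skips all `requires` tests
# pytest --only-extended  -> runs `requires` tests; a missing package is a failure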
@@ -20,6 +20,7 @@ from langchain.callbacks.human import HumanApprovalCallbackHandler
from langchain.callbacks.infino_callback import InfinoCallbackHandler
from langchain.callbacks.labelstudio_callback import LabelStudioCallbackHandler
from langchain.callbacks.manager import (
    collect_runs,
    get_openai_callback,
    tracing_enabled,
    tracing_v2_enabled,

@@ -66,6 +67,7 @@ __all__ = [
    "get_openai_callback",
    "tracing_enabled",
    "tracing_v2_enabled",
    "collect_runs",
    "wandb_tracing_enabled",
    "FlyteCallbackHandler",
    "SageMakerCallbackHandler",
@@ -38,6 +38,7 @@ from langchain.callbacks.base import (
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import run_collector
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler

@@ -75,6 +76,11 @@ tracing_v2_callback_var: ContextVar[
] = ContextVar(  # noqa: E501
    "tracing_callback_v2", default=None
)
run_collector_var: ContextVar[
    Optional[run_collector.RunCollectorCallbackHandler]
] = ContextVar(  # noqa: E501
    "run_collector", default=None
)


def _get_debug() -> bool:

@@ -184,6 +190,24 @@ def tracing_v2_enabled(
        tracing_v2_callback_var.set(None)


@contextmanager
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
    """Collect all run traces in context.

    Returns:
        run_collector.RunCollectorCallbackHandler: The run collector callback handler.

    Example:
        >>> with collect_runs() as runs_cb:
                chain.invoke("foo")
                run_id = runs_cb.traced_runs[0].id
    """
    cb = run_collector.RunCollectorCallbackHandler()
    run_collector_var.set(cb)
    yield cb
    run_collector_var.set(None)


@contextmanager
def trace_as_chain_group(
    group_name: str,

@@ -1712,6 +1736,7 @@ def _configure(
    tracer_project = os.environ.get(
        "LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
    )
    run_collector_ = run_collector_var.get()
    debug = _get_debug()
    if (
        verbose

@@ -1774,4 +1799,6 @@ def _configure(
        for handler in callback_manager.handlers
    ):
        callback_manager.add_handler(open_ai, True)
    if run_collector_ is not None:
        callback_manager.add_handler(run_collector_, False)
    return callback_manager
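A hedged usage sketch of the new collect_runs context manager (FakeListLLM is used only as a stand-in model so the snippet runs without API keys):

from langchain.callbacks import collect_runs
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["hello"])
with collect_runs() as cb:
    llm.predict("hi")
    # every run started inside the block is recorded on the handler
    for run in cb.traced_runs:
        print(run.id, run.name)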
@@ -3,10 +3,11 @@ from __future__ import annotations

import logging
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, List, Optional, Sequence, Set, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Union
from uuid import UUID

from langsmith import Client, RunEvaluator
import langsmith
from langsmith import schemas as langsmith_schemas

from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.base import BaseTracer

@@ -62,13 +63,13 @@ class EvaluatorCallbackHandler(BaseTracer):
        The LangSmith project name to be organize eval chain runs under.
    """

    name: str = "evaluator_callback_handler"
    name = "evaluator_callback_handler"

    def __init__(
        self,
        evaluators: Sequence[RunEvaluator],
        evaluators: Sequence[langsmith.RunEvaluator],
        max_workers: Optional[int] = None,
        client: Optional[Client] = None,
        client: Optional[langsmith.Client] = None,
        example_id: Optional[Union[UUID, str]] = None,
        skip_unfinished: bool = True,
        project_name: Optional[str] = "evaluators",

@@ -86,10 +87,11 @@ class EvaluatorCallbackHandler(BaseTracer):
        self.futures: Set[Future] = set()
        self.skip_unfinished = skip_unfinished
        self.project_name = project_name
        self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
        global _TRACERS
        _TRACERS.append(self)

    def _evaluate_in_project(self, run: Run, evaluator: RunEvaluator) -> None:
    def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
        """Evaluate the run in the project.

        Parameters

@@ -102,11 +104,11 @@ class EvaluatorCallbackHandler(BaseTracer):
        """
        try:
            if self.project_name is None:
                self.client.evaluate_run(run, evaluator)
                feedback = self.client.evaluate_run(run, evaluator)
            with tracing_v2_enabled(
                project_name=self.project_name, tags=["eval"], client=self.client
            ):
                self.client.evaluate_run(run, evaluator)
                feedback = self.client.evaluate_run(run, evaluator)
        except Exception as e:
            logger.error(
                f"Error evaluating run {run.id} with "

@@ -114,6 +116,8 @@ class EvaluatorCallbackHandler(BaseTracer):
                exc_info=True,
            )
            raise e
        example_id = str(run.reference_example_id)
        self.logged_feedback.setdefault(example_id, []).append(feedback)

    def _persist_run(self, run: Run) -> None:
        """Run the evaluator on the run.
libs/langchain/langchain/chat_loaders/imessage.py (new file, 117 lines)
@@ -0,0 +1,117 @@
"""IMessage Chat Loader.

This class is used to load chat sessions from the iMessage chat.db SQLite file.
It only works on macOS when you have iMessage enabled and have the chat.db file.

The chat.db file is likely located at ~/Library/Messages/chat.db. However, your
terminal may not have permission to access this file. To resolve this, you can
copy the file to a different location, change the permissions of the file, or
grant full disk access for your terminal emulator in System Settings > Security
and Privacy > Full Disk Access.
"""
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from langchain import schema
from langchain.chat_loaders import base as chat_loaders

if TYPE_CHECKING:
    import sqlite3


class IMessageChatLoader(chat_loaders.BaseChatLoader):
    def __init__(self, path: Optional[Union[str, Path]] = None):
        """
        Initialize the IMessageChatLoader.

        Args:
            path (str or Path, optional): Path to the chat.db SQLite file.
                Defaults to None, in which case the default path
                ~/Library/Messages/chat.db will be used.
        """
        if path is None:
            path = Path.home() / "Library" / "Messages" / "chat.db"
        self.db_path = path if isinstance(path, Path) else Path(path)
        if not self.db_path.exists():
            raise FileNotFoundError(f"File {self.db_path} not found")
        try:
            import sqlite3  # noqa: F401  # check that the module is importable
        except ImportError as e:
            raise ImportError(
                "The sqlite3 module is required to load iMessage chats.\n"
                "Please install it with `pip install pysqlite3`"
            ) from e

    def _load_single_chat_session(
        self, cursor: "sqlite3.Cursor", chat_id: int
    ) -> chat_loaders.ChatSession:
        """
        Load a single chat session from the iMessage chat.db.

        Args:
            cursor: SQLite cursor object.
            chat_id (int): ID of the chat session to load.

        Returns:
            ChatSession: Loaded chat session.
        """
        results: List[schema.HumanMessage] = []

        query = """
        SELECT message.date, handle.id, message.text
        FROM message
        JOIN chat_message_join ON message.ROWID = chat_message_join.message_id
        JOIN handle ON message.handle_id = handle.ROWID
        WHERE chat_message_join.chat_id = ?
        ORDER BY message.date ASC;
        """
        cursor.execute(query, (chat_id,))
        messages = cursor.fetchall()

        for date, sender, text in messages:
            if text:  # Skip empty messages
                results.append(
                    schema.HumanMessage(
                        role=sender,
                        content=text,
                        additional_kwargs={
                            "message_time": date,
                            "sender": sender,
                        },
                    )
                )

        return chat_loaders.ChatSession(messages=results)

    def lazy_load(self) -> Iterator[chat_loaders.ChatSession]:
        """
        Lazy load the chat sessions from the iMessage chat.db
        and yield them in the required format.

        Yields:
            ChatSession: Loaded chat session.
        """
        import sqlite3  # imported here so importing the module never requires it

        try:
            conn = sqlite3.connect(self.db_path)
        except sqlite3.OperationalError as e:
            raise ValueError(
                f"Could not open iMessage DB file {self.db_path}.\n"
                "Make sure your terminal emulator has disk access to this file.\n"
                " You can either copy the DB file to an accessible location"
                " or grant full disk access for your terminal emulator."
                " You can grant full disk access for your terminal emulator"
                " in System Settings > Security and Privacy > Full Disk Access."
            ) from e
        cursor = conn.cursor()

        # Fetch the list of chat IDs
        cursor.execute("SELECT ROWID FROM chat")
        chat_ids = [row[0] for row in cursor.fetchall()]

        for chat_id in chat_ids:
            yield self._load_single_chat_session(cursor, chat_id)

        conn.close()
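A hedged usage sketch of the new loader (the path points at a hypothetical copy of chat.db made to sidestep Full Disk Access; ChatSession behaves like a dict with a "messages" key):

from langchain.chat_loaders.imessage import IMessageChatLoader

loader = IMessageChatLoader(path="/tmp/chat.db")  # copied database, placeholder path
for session in loader.lazy_load():
    for message in session["messages"]:
        print(message.additional_kwargs["sender"], message.content)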
libs/langchain/langchain/llms/grammars/json.gbnf (new file, 29 lines)
@@ -0,0 +1,29 @@
# Grammar for subset of JSON - doesn't support full string or number syntax

root   ::= object
value  ::= object | array | string | number | boolean | "null"

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}"

array  ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]"

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

# Only plain integers currently
number ::= "-"? [0-9]+ ws
boolean ::= ("true" | "false") ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
libs/langchain/langchain/llms/grammars/list.gbnf (new file, 14 lines)
@@ -0,0 +1,14 @@
root ::= "[" items "]" EOF

items ::= item ("," ws* item)*

item ::= string

string ::=
  "\"" word (ws+ word)* "\"" ws*

word ::= [a-zA-Z]+

ws ::= " "

EOF ::= "\n"
@@ -1,5 +1,8 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Iterator, List, Optional
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

@@ -8,6 +11,9 @@ from langchain.schema.output import GenerationChunk
from langchain.utils import get_pydantic_field_names
from langchain.utils.utils import build_extra_kwargs

if TYPE_CHECKING:
    from llama_cpp import LlamaGrammar

logger = logging.getLogger(__name__)


@@ -113,12 +119,35 @@ class LlamaCpp(LLM):
    streaming: bool = True
    """Whether to stream the results, token by token."""

    grammar_path: Optional[Union[str, Path]] = None
    """
    grammar_path: Path to the .gbnf file that defines formal grammars
    for constraining model outputs. For instance, the grammar can be used
    to force the model to generate valid JSON or to speak exclusively in emojis. At most
    one of grammar_path and grammar should be passed in.
    """
    grammar: Optional[Union[str, LlamaGrammar]] = None
    """
    grammar: formal grammar for constraining model outputs. For instance, the grammar
    can be used to force the model to generate valid JSON or to speak exclusively in
    emojis. At most one of grammar_path and grammar should be passed in.
    """

    verbose: bool = True
    """Print verbose output to stderr."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that llama-cpp-python library is installed."""
        try:
            from llama_cpp import Llama, LlamaGrammar
        except ImportError:
            raise ImportError(
                "Could not import llama-cpp-python library. "
                "Please install the llama-cpp-python library to "
                "use this embedding model: pip install llama-cpp-python"
            )

        model_path = values["model_path"]
        model_param_names = [
            "rope_freq_scale",

@@ -146,21 +175,26 @@ class LlamaCpp(LLM):
            model_params.update(values["model_kwargs"])

        try:
            from llama_cpp import Llama

            values["client"] = Llama(model_path, **model_params)
        except ImportError:
            raise ImportError(
                "Could not import llama-cpp-python library. "
                "Please install the llama-cpp-python library to "
                "use this embedding model: pip install llama-cpp-python"
            )
        except Exception as e:
            raise ValueError(
                f"Could not load Llama model from path: {model_path}. "
                f"Received error {e}"
            )

        if values["grammar"] and values["grammar_path"]:
            grammar = values["grammar"]
            grammar_path = values["grammar_path"]
            raise ValueError(
                "Can only pass in one of grammar and grammar_path. Received "
                f"{grammar=} and {grammar_path=}."
            )
        elif isinstance(values["grammar"], str):
            values["grammar"] = LlamaGrammar.from_string(values["grammar"])
        elif values["grammar_path"]:
            values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
        else:
            pass
        return values

    @root_validator(pre=True)

@@ -176,7 +210,7 @@ class LlamaCpp(LLM):
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling llama_cpp."""
        return {
        params = {
            "suffix": self.suffix,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,

@@ -187,6 +221,9 @@ class LlamaCpp(LLM):
            "repeat_penalty": self.repeat_penalty,
            "top_k": self.top_k,
        }
        if self.grammar:
            params["grammar"] = self.grammar
        return params

    @property
    def _identifying_params(self) -> Dict[str, Any]:

@@ -252,7 +289,10 @@ class LlamaCpp(LLM):
        # and return the combined strings from the first choices's text:
        combined_text_output = ""
        for chunk in self._stream(
            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            prompt=prompt,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            combined_text_output += chunk.text
        return combined_text_output
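A hedged usage sketch of the new grammar support (the model path is a placeholder, the grammar file is the json.gbnf added above, and at most one of grammar / grammar_path may be set):

from langchain.llms import LlamaCpp

llm = LlamaCpp(
    model_path="/path/to/model.gguf",  # placeholder local model file
    grammar_path="libs/langchain/langchain/llms/grammars/json.gbnf",
    temperature=0.0,
)
print(llm("Describe a person in JSON format:"))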
@@ -2,7 +2,7 @@
import json
import logging
from pathlib import Path
from typing import Union
from typing import Callable, Dict, Union

import yaml

@@ -26,10 +26,7 @@ def load_prompt_from_config(config: dict) -> BasePromptTemplate:
        raise ValueError(f"Loading {config_type} prompt not supported")

    prompt_loader = type_to_loader_dict[config_type]
    # Unclear why type error is being thrown here.
    # Incompatible return value type (got "Runnable[Dict[Any, Any], PromptValue]",
    # expected "BasePromptTemplate") [return-value]
    return prompt_loader(config)  # type: ignore[return-value]
    return prompt_loader(config)


def _load_template(var_name: str, config: dict) -> dict:

@@ -148,8 +145,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
    return load_prompt_from_config(config)


type_to_loader_dict = {
type_to_loader_dict: Dict[str, Callable[[dict], BasePromptTemplate]] = {
    "prompt": _load_prompt,
    "few_shot": _load_few_shot_prompt,
    # "few_shot_with_templates": _load_few_shot_with_templates_prompt,
}
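For orientation, a small sketch of what type_to_loader_dict dispatches on: the "_type" key of a prompt config selects the loader (the config values here are illustrative):

from langchain.prompts.loading import load_prompt_from_config

config = {
    "_type": "prompt",
    "input_variables": ["product"],
    "template": "What is a good name for a company that makes {product}?",
}
prompt = load_prompt_from_config(config)
print(prompt.format(product="colorful socks"))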
@@ -11,6 +11,7 @@ import uuid
import warnings
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Coroutine,

@@ -44,6 +45,9 @@ from langchain.schema.runnable import Runnable, RunnableConfig, RunnableLambda
from langchain.smith.evaluation.config import EvalConfig, RunEvalConfig
from langchain.smith.evaluation.string_run_evaluator import StringRunEvaluatorChain

if TYPE_CHECKING:
    import pandas as pd

logger = logging.getLogger(__name__)

MODEL_OR_CHAIN_FACTORY = Union[

@@ -63,6 +67,31 @@ class InputFormatError(Exception):
## Shared Utilities


class TestResult(dict):
    """A dictionary of the results of a single test run."""

    def to_dataframe(self) -> pd.DataFrame:
        """Convert the results to a dataframe."""
        try:
            import pandas as pd
        except ImportError as e:
            raise ImportError(
                "Pandas is required to convert the results to a dataframe."
                " to install pandas, run `pip install pandas`."
            ) from e

        indices = []
        records = []
        for example_id, result in self["results"].items():
            feedback = result["feedback"]
            records.append(
                {**{f.key: f.score for f in feedback}, "output": result["output"]}
            )
            indices.append(example_id)

        return pd.DataFrame(records, index=indices)


def _get_eval_project_url(api_url: str, project_id: str) -> str:
    """Get the project url from the api url."""
    parsed = urlparse(api_url)

@@ -667,7 +696,7 @@ async def _arun_llm_or_chain(
    tags: Optional[List[str]] = None,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
    input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
) -> Union[dict, str, LLMResult, ChatResult]:
    """Asynchronously run the Chain or language model.

    Args:

@@ -689,10 +718,10 @@ async def _arun_llm_or_chain(
                tracer.example_id = example.id
    else:
        previous_example_ids = None
    outputs = []
    chain_or_llm = (
        "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
    )
    result = None
    try:
        if isinstance(llm_or_chain_factory, BaseLanguageModel):
            output: Any = await _arun_llm(

@@ -711,15 +740,15 @@ async def _arun_llm_or_chain(
                callbacks=callbacks,
                input_mapper=input_mapper,
            )
        outputs.append(output)
        result = output
    except Exception as e:
        logger.warning(f"{chain_or_llm} failed for example {example.id}. Error: {e}")
        outputs.append({"Error": str(e)})
        result = {"Error": str(e)}
    if callbacks and previous_example_ids:
        for example_id, tracer in zip(previous_example_ids, callbacks):
            if hasattr(tracer, "example_id"):
                tracer.example_id = example_id
    return outputs
    return result


async def _gather_with_concurrency(

@@ -856,7 +885,7 @@ async def _arun_on_examples(
        wrapped_model, examples, evaluation, data_type
    )
    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
    results: Dict[str, List[Any]] = {}
    results: Dict[str, dict] = {}

    async def process_example(
        example: Example, callbacks: List[BaseCallbackHandler], job_state: dict

@@ -869,7 +898,7 @@ async def _arun_on_examples(
            callbacks=callbacks,
            input_mapper=input_mapper,
        )
        results[str(example.id)] = result
        results[str(example.id)] = {"output": result}
        job_state["num_processed"] += 1
        if verbose:
            print(

@@ -890,8 +919,14 @@ async def _arun_on_examples(
        ),
        *(functools.partial(process_example, e) for e in examples),
    )
    all_feedback = {}
    for handler in evaluation_handlers:
        handler.wait_for_futures()
        all_feedback.update(handler.logged_feedback)
    # join the results and feedback on the example id
    for example_id, output_dict in results.items():
        feedback = all_feedback.get(example_id, [])
        output_dict["feedback"] = feedback
    return results


@@ -978,7 +1013,7 @@ def _run_llm_or_chain(
    tags: Optional[List[str]] = None,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
    input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
) -> Union[dict, str, LLMResult, ChatResult]:
    """
    Run the Chain or language model synchronously.

@@ -1001,10 +1036,10 @@ def _run_llm_or_chain(
                tracer.example_id = example.id
    else:
        previous_example_ids = None
    outputs = []
    chain_or_llm = (
        "LLM" if isinstance(llm_or_chain_factory, BaseLanguageModel) else "Chain"
    )
    result = None
    try:
        if isinstance(llm_or_chain_factory, BaseLanguageModel):
            output: Any = _run_llm(

@@ -1023,18 +1058,18 @@ def _run_llm_or_chain(
                tags=tags,
                input_mapper=input_mapper,
            )
        outputs.append(output)
        result = output
    except Exception as e:
        logger.warning(
            f"{chain_or_llm} failed for example {example.id} with inputs:"
            f" {example.inputs}.\nError: {e}",
        )
        outputs.append({"Error": str(e)})
        result = {"Error": str(e)}
    if callbacks and previous_example_ids:
        for example_id, tracer in zip(previous_example_ids, callbacks):
            if hasattr(tracer, "example_id"):
                tracer.example_id = example_id
    return outputs
    return result


def _run_on_examples(

@@ -1075,7 +1110,7 @@ def _run_on_examples(
    Returns:
        A dictionary mapping example ids to the model outputs.
    """
    results: Dict[str, Any] = {}
    results: Dict[str, dict] = {}
    wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory)
    project_name = _get_project_name(project_name, wrapped_model)
    tracer = LangChainTracer(

@@ -1085,11 +1120,11 @@ def _run_on_examples(
        wrapped_model, examples, evaluation, data_type
    )
    examples = _validate_example_inputs(examples, wrapped_model, input_mapper)
    evalution_handler = EvaluatorCallbackHandler(
    evaluation_handler = EvaluatorCallbackHandler(
        evaluators=run_evaluators or [],
        client=client,
    )
    callbacks: List[BaseCallbackHandler] = [tracer, evalution_handler]
    callbacks: List[BaseCallbackHandler] = [tracer, evaluation_handler]
    for i, example in enumerate(examples):
        result = _run_llm_or_chain(
            example,

@@ -1100,9 +1135,14 @@ def _run_on_examples(
        )
        if verbose:
            print(f"{i+1} processed", flush=True, end="\r")
        results[str(example.id)] = result
        results[str(example.id)] = {"output": result}
    tracer.wait_for_futures()
    evalution_handler.wait_for_futures()
    evaluation_handler.wait_for_futures()
    all_feedback = evaluation_handler.logged_feedback
    # join the results and feedback on the example id
    for example_id, output_dict in results.items():
        feedback = all_feedback.get(example_id, [])
        output_dict["feedback"] = feedback
    return results


@@ -1276,10 +1316,10 @@ async def arun_on_dataset(
        input_mapper=input_mapper,
        data_type=dataset.data_type,
    )
    return {
        "project_name": project_name,
        "results": results,
    }
    return TestResult(
        project_name=project_name,
        results=results,
    )


def _handle_coroutine(coro: Coroutine) -> Any:

@@ -1461,7 +1501,7 @@ def run_on_dataset(
        data_type=dataset.data_type,
    )
    results = _handle_coroutine(coro)
    return {
        "project_name": project_name,
        "results": results,
    }
    return TestResult(
        project_name=project_name,
        results=results,
    )
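A hedged sketch of consuming the new TestResult return value (the dataset name and evaluator config are placeholders; to_dataframe needs pandas, and a configured LangSmith client is required to actually run this):

from langchain.chat_models import ChatOpenAI
from langchain.smith import RunEvalConfig, run_on_dataset
from langsmith import Client

client = Client()
results = run_on_dataset(
    client=client,
    dataset_name="my-eval-dataset",  # hypothetical dataset
    llm_or_chain_factory=ChatOpenAI(temperature=0),
    evaluation=RunEvalConfig(evaluators=["qa"]),
)
df = results.to_dataframe()  # one row per example; feedback scores become columns
print(df.head())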
@@ -19,7 +19,9 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
    if isinstance(tool, StructuredTool):
        schema_ = tool.args_schema.schema()
        # Bug with required missing for structured tools.
        required = sorted(schema_["properties"])  # BUG WORKAROUND
        required = schema_.get(
            "required", sorted(schema_["properties"])  # Backup is a BUG WORKAROUND
        )
        return {
            "name": tool.name,
            "description": tool.description,
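For orientation, a hedged sketch of the structured-tool path touched above (the multiply tool is a hypothetical example; with the fix, "required" is taken from the tool's args schema when pydantic provides it, falling back to all properties):

from langchain.tools import StructuredTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function


def multiply(a: int, b: int = 2) -> int:
    """Multiply a by b."""
    return a * b


tool = StructuredTool.from_function(multiply)
function_description = format_tool_to_openai_function(tool)
print(function_description["name"])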
@@ -1298,7 +1298,7 @@ class Qdrant(VectorStore):
                embeddings = OpenAIEmbeddings()
                qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
        """
        qdrant = cls._construct_instance(
        qdrant = cls.construct_instance(
            texts,
            embedding,
            location,

@@ -1474,7 +1474,7 @@ class Qdrant(VectorStore):
                embeddings = OpenAIEmbeddings()
                qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
        """
        qdrant = await cls._aconstruct_instance(
        qdrant = await cls.aconstruct_instance(
            texts,
            embedding,
            location,

@@ -1510,7 +1510,7 @@ class Qdrant(VectorStore):
        return qdrant

    @classmethod
    def _construct_instance(
    def construct_instance(
        cls: Type[Qdrant],
        texts: List[str],
        embedding: Embeddings,

@@ -1676,7 +1676,7 @@ class Qdrant(VectorStore):
        return qdrant

    @classmethod
    async def _aconstruct_instance(
    async def aconstruct_instance(
        cls: Type[Qdrant],
        texts: List[str],
        embedding: Embeddings,
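A hedged sketch of the now-public construct_instance path via from_texts (uses qdrant-client's in-memory mode and FakeEmbeddings so no external services are needed; the collection name is arbitrary):

from langchain.embeddings import FakeEmbeddings
from langchain.vectorstores import Qdrant

texts = ["harrison worked at kensho", "langchain can talk to qdrant"]
embeddings = FakeEmbeddings(size=16)
qdrant = Qdrant.from_texts(
    texts,
    embeddings,
    location=":memory:",
    collection_name="demo",
)
print(qdrant.similarity_search("where did harrison work?", k=1))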
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.0.274"
version = "0.0.275"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -0,0 +1,16 @@
"""Test the run collector."""

import uuid

from langchain.callbacks import collect_runs
from tests.unit_tests.llms.fake_llm import FakeLLM


def test_collect_runs() -> None:
    llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
    with collect_runs() as cb:
        llm.predict("hi")
        assert cb.traced_runs
        assert len(cb.traced_runs) == 1
        assert isinstance(cb.traced_runs[0].id, uuid.UUID)
        assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
@@ -182,14 +182,12 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
        return {"the right input": inputs["the wrong input"]}

    result = _run_llm_or_chain(example, lambda: mock_chain, input_mapper=input_mapper)
    assert len(result) == 1
    assert result[0] == {"output": "2", "the right input": "1"}
    assert result == {"output": "2", "the right input": "1"}
    bad_result = _run_llm_or_chain(
        example,
        lambda: mock_chain,
    )
    assert len(bad_result) == 1
    assert "Error" in bad_result[0]
    assert "Error" in bad_result

    # Try with LLM
    def llm_input_mapper(inputs: dict) -> str:

@@ -197,9 +195,7 @@ def test_run_llm_or_chain_with_input_mapper() -> None:
        return "the right input"

    mock_llm = FakeLLM(queries={"the right input": "somenumber"})
    result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
    assert len(result) == 1
    llm_result = result[0]
    llm_result = _run_llm_or_chain(example, mock_llm, input_mapper=llm_input_mapper)
    assert isinstance(llm_result, str)
    assert llm_result == "somenumber"

@@ -300,8 +296,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
        tags: Optional[List[str]] = None,
        callbacks: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        return [{"result": f"Result for example {example.id}"}]
    ) -> Dict[str, Any]:
        return {"result": f"Result for example {example.id}"}

    def mock_create_project(*args: Any, **kwargs: Any) -> Any:
        proj = mock.MagicMock()

@@ -328,9 +324,10 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    )

    expected = {
        uuid_: [
            {"result": f"Result for example {uuid.UUID(uuid_)}"} for _ in range(1)
        ]
        uuid_: {
            "output": {"result": f"Result for example {uuid.UUID(uuid_)}"},
            "feedback": [],
        }
        for uuid_ in uuids
    }
    assert results["results"] == expected