mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 13:55:03 +00:00
add rules and fixes
This commit is contained in:
parent
8064d3bdc4
commit
1be308db81
@ -1,4 +1,4 @@
|
|||||||
"""This package provides the Perplexity integration for LangChain."""
|
"""Provides the Perplexity integration for LangChain."""
|
||||||
|
|
||||||
from langchain_perplexity.chat_models import ChatPerplexity
|
from langchain_perplexity.chat_models import ChatPerplexity
|
||||||
|
|
||||||
|
@ -183,7 +183,8 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
extra = values.get("model_kwargs", {})
|
extra = values.get("model_kwargs", {})
|
||||||
for field_name in list(values):
|
for field_name in list(values):
|
||||||
if field_name in extra:
|
if field_name in extra:
|
||||||
raise ValueError(f"Found {field_name} supplied twice.")
|
msg = f"Found {field_name} supplied twice."
|
||||||
|
raise ValueError(msg)
|
||||||
if field_name not in all_required_field_names:
|
if field_name not in all_required_field_names:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"""WARNING! {field_name} is not a default parameter.
|
f"""WARNING! {field_name} is not a default parameter.
|
||||||
@ -194,10 +195,11 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
|
|
||||||
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
|
||||||
if invalid_model_kwargs:
|
if invalid_model_kwargs:
|
||||||
raise ValueError(
|
msg = (
|
||||||
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
|
||||||
f"Instead they were passed in as part of `model_kwargs` parameter."
|
f"Instead they were passed in as part of `model_kwargs` parameter."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
values["model_kwargs"] = extra
|
values["model_kwargs"] = extra
|
||||||
return values
|
return values
|
||||||
@ -213,11 +215,12 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
base_url="https://api.perplexity.ai",
|
base_url="https://api.perplexity.ai",
|
||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||||
"due to an old version of the openai package. Try upgrading it "
|
"due to an old version of the openai package. Try upgrading it "
|
||||||
"with `pip install --upgrade openai`."
|
"with `pip install --upgrade openai`."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg) from None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -240,7 +243,8 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
elif isinstance(message, AIMessage):
|
elif isinstance(message, AIMessage):
|
||||||
message_dict = {"role": "assistant", "content": message.content}
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Got unknown type {message}")
|
msg = f"Got unknown type {message}"
|
||||||
|
raise TypeError(msg)
|
||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
def _create_message_dicts(
|
def _create_message_dicts(
|
||||||
@ -249,7 +253,8 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
params = dict(self._invocation_params)
|
params = dict(self._invocation_params)
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
if "stop" in params:
|
if "stop" in params:
|
||||||
raise ValueError("`stop` found in both the input and default params.")
|
msg = "`stop` found in both the input and default params."
|
||||||
|
raise ValueError(msg)
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
||||||
return message_dicts, params
|
return message_dicts, params
|
||||||
@ -270,18 +275,17 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
|
|
||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
if role == "assistant" or default_class == AIMessageChunk:
|
||||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||||
elif role == "system" or default_class == SystemMessageChunk:
|
if role == "system" or default_class == SystemMessageChunk:
|
||||||
return SystemMessageChunk(content=content)
|
return SystemMessageChunk(content=content)
|
||||||
elif role == "function" or default_class == FunctionMessageChunk:
|
if role == "function" or default_class == FunctionMessageChunk:
|
||||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||||
elif role == "tool" or default_class == ToolMessageChunk:
|
if role == "tool" or default_class == ToolMessageChunk:
|
||||||
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
|
||||||
elif role or default_class == ChatMessageChunk:
|
if role or default_class == ChatMessageChunk:
|
||||||
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
|
||||||
else:
|
return default_class(content=content) # type: ignore[call-arg]
|
||||||
return default_class(content=content) # type: ignore[call-arg]
|
|
||||||
|
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
@ -409,8 +413,9 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
"""Model wrapper that returns outputs formatted to match the given schema for Preplexity.
|
"""Model wrapper that returns outputs formatted to match the given schema for Preplexity.
|
||||||
|
|
||||||
Currently, Perplexity only supports "json_schema" method for structured output
|
Currently, Perplexity only supports "json_schema" method for structured output
|
||||||
as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs
|
as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
schema:
|
schema:
|
||||||
@ -456,10 +461,11 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
method = "json_schema"
|
method = "json_schema"
|
||||||
if method == "json_schema":
|
if method == "json_schema":
|
||||||
if schema is None:
|
if schema is None:
|
||||||
raise ValueError(
|
msg = (
|
||||||
"schema must be specified when method is not 'json_schema'. "
|
"schema must be specified when method is not 'json_schema'. "
|
||||||
"Received None."
|
"Received None."
|
||||||
)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
is_pydantic_schema = _is_pydantic_class(schema)
|
is_pydantic_schema = _is_pydantic_class(schema)
|
||||||
response_format = convert_to_json_schema(schema)
|
response_format = convert_to_json_schema(schema)
|
||||||
llm = self.bind(
|
llm = self.bind(
|
||||||
@ -478,10 +484,9 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
else JsonOutputParser()
|
else JsonOutputParser()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
msg = f"Unrecognized method argument. Expected 'json_schema' Received:\
|
||||||
f"Unrecognized method argument. Expected 'json_schema' Received:\
|
|
||||||
'{method}'"
|
'{method}'"
|
||||||
)
|
raise ValueError(msg)
|
||||||
|
|
||||||
if include_raw:
|
if include_raw:
|
||||||
parser_assign = RunnablePassthrough.assign(
|
parser_assign = RunnablePassthrough.assign(
|
||||||
@ -492,5 +497,4 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
[parser_none], exception_key="parsing_error"
|
[parser_none], exception_key="parsing_error"
|
||||||
)
|
)
|
||||||
return RunnableMap(raw=llm) | parser_with_fallback
|
return RunnableMap(raw=llm) | parser_with_fallback
|
||||||
else:
|
return llm | output_parser
|
||||||
return llm | output_parser
|
|
||||||
|
@ -59,8 +59,63 @@ ignore_missing_imports = true
|
|||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "I", "T201", "UP", "S"]
|
select = [
|
||||||
ignore = [ "UP007", ]
|
"A", # flake8-builtins
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"ASYNC", # flake8-async
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"COM", # flake8-commas
|
||||||
|
"D", # pydocstyle
|
||||||
|
"DOC", # pydoclint
|
||||||
|
"E", # pycodestyle error
|
||||||
|
"EM", # flake8-errmsg
|
||||||
|
"F", # pyflakes
|
||||||
|
"FA", # flake8-future-annotations
|
||||||
|
"FBT", # flake8-boolean-trap
|
||||||
|
"FLY", # flake8-flynt
|
||||||
|
"I", # isort
|
||||||
|
"ICN", # flake8-import-conventions
|
||||||
|
"INT", # flake8-gettext
|
||||||
|
"ISC", # isort-comprehensions
|
||||||
|
"PGH", # pygrep-hooks
|
||||||
|
"PIE", # flake8-pie
|
||||||
|
"PERF", # flake8-perf
|
||||||
|
"PYI", # flake8-pyi
|
||||||
|
"Q", # flake8-quotes
|
||||||
|
"RET", # flake8-return
|
||||||
|
"RSE", # flake8-rst-docstrings
|
||||||
|
"RUF", # ruff
|
||||||
|
"S", # flake8-bandit
|
||||||
|
"SLF", # flake8-self
|
||||||
|
"SLOT", # flake8-slots
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"T10", # flake8-debugger
|
||||||
|
"T20", # flake8-print
|
||||||
|
"TID", # flake8-tidy-imports
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"W", # pycodestyle warning
|
||||||
|
"YTT", # flake8-2020
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"D100", # pydocstyle: Missing docstring in public module
|
||||||
|
"D101", # pydocstyle: Missing docstring in public class
|
||||||
|
"D102", # pydocstyle: Missing docstring in public method
|
||||||
|
"D103", # pydocstyle: Missing docstring in public function
|
||||||
|
"D104", # pydocstyle: Missing docstring in public package
|
||||||
|
"D105", # pydocstyle: Missing docstring in magic method
|
||||||
|
"D107", # pydocstyle: Missing docstring in __init__
|
||||||
|
"D203", # Messes with the formatter
|
||||||
|
"D407", # pydocstyle: Missing-dashed-underline-after-section
|
||||||
|
"COM812", # Messes with the formatter
|
||||||
|
"ISC001", # Messes with the formatter
|
||||||
|
"PERF203", # Rarely useful
|
||||||
|
"S112", # Rarely useful
|
||||||
|
"RUF012", # Doesn't play well with Pydantic
|
||||||
|
"SLF001", # Private member access
|
||||||
|
"UP007", # pyupgrade: non-pep604-annotation-union
|
||||||
|
"UP045", # pyupgrade: non-pep604-annotation-optional
|
||||||
|
]
|
||||||
|
unfixable = ["B028"] # People should intentionally tune the stacklevel
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
|
@ -4,4 +4,3 @@ import pytest # type: ignore[import-not-found]
|
|||||||
@pytest.mark.compile
|
@pytest.mark.compile
|
||||||
def test_placeholder() -> None:
|
def test_placeholder() -> None:
|
||||||
"""Used for compiling integration tests without running any real tests."""
|
"""Used for compiling integration tests without running any real tests."""
|
||||||
pass
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user