Compare commits

...

2 Commits

Author SHA1 Message Date
Maddy Adams
d439403dfa fix 2024-06-16 19:26:51 -07:00
Maddy Adams
a1526aeb3c feat: support passing options to hub.pull 2024-06-16 13:19:11 -07:00

View File

@@ -10,7 +10,7 @@ from langchain_core.load.load import loads
from langchain_core.prompts import BasePromptTemplate
if TYPE_CHECKING:
from langchainhub import Client
from langchainhub import Client, _types
def _get_client(api_url: Optional[str] = None, api_key: Optional[str] = None) -> Client:
@@ -26,6 +26,14 @@ def _get_client(api_url: Optional[str] = None, api_key: Optional[str] = None) ->
return Client(api_url, api_key=api_key)
def _add_metadata(obj: Any, repo_dict: _types.Repo) -> Any:
if obj.metadata is None:
obj.metadata = {}
obj.metadata["lc_hub_owner"] = repo_dict["owner"]
obj.metadata["lc_hub_repo"] = repo_dict["repo"]
obj.metadata["lc_hub_commit_hash"] = repo_dict["commit_hash"]
def push(
repo_full_name: str,
object: Any,
@@ -33,7 +41,7 @@ def push(
api_url: Optional[str] = None,
api_key: Optional[str] = None,
parent_commit_hash: Optional[str] = "latest",
new_repo_is_public: bool = True,
new_repo_is_public: bool = False,
new_repo_description: str = "",
) -> str:
"""
@@ -48,7 +56,7 @@ def push(
:param parent_commit_hash: The commit hash of the parent commit to push to. Defaults
to the latest commit automatically.
:param new_repo_is_public: Whether the repo should be public. Defaults to
True (Public by default).
False (Private by default).
:param new_repo_description: The description of the repo. Defaults to an empty
string.
"""
@@ -69,6 +77,7 @@ def pull(
*,
api_url: Optional[str] = None,
api_key: Optional[str] = None,
options: Optional[dict[str, Any]] = None,
) -> Any:
"""
Pull an object from the hub and returns it as a LangChain object.
@@ -83,14 +92,20 @@ def pull(
if hasattr(client, "pull_repo"):
# >= 0.1.15
res_dict = client.pull_repo(owner_repo_commit)
res_dict = {}
try:
res_dict = client.pull_repo(owner_repo_commit, options=options)
except Exception:
try:
res_dict = client.pull_repo(owner_repo_commit)
except Exception as e:
raise e
obj = loads(json.dumps(res_dict["manifest"]))
if isinstance(obj, BasePromptTemplate):
if obj.metadata is None:
obj.metadata = {}
obj.metadata["lc_hub_owner"] = res_dict["owner"]
obj.metadata["lc_hub_repo"] = res_dict["repo"]
obj.metadata["lc_hub_commit_hash"] = res_dict["commit_hash"]
_add_metadata(obj=obj, repo_dict=res_dict)
elif isinstance(obj.first, BasePromptTemplate):
_add_metadata(obj=obj.first, repo_dict=res_dict)
return obj
# Then it's < 0.1.15