community[minor]: allow enabling proxy in aiohttp session in AsyncHTML (#19499)

Allow enabling a proxy for the aiohttp session used by AsyncHtmlLoader: a new trust_env flag is passed through to aiohttp.ClientSession, so proxy settings can be picked up from the http_proxy/https_proxy environment variables or from .netrc. The flag is also exposed on WebResearchRetriever, which forwards it to its internal AsyncHtmlLoader.
Sihan Chen authored on 2024-05-23 02:25:06 +08:00 (committed by GitHub)
commit 1f81277b9b, parent 36813d2f00
3 changed files with 20 additions and 2 deletions

libs/community/langchain_community/document_loaders/async_html.py

@@ -64,6 +64,7 @@ class AsyncHtmlLoader(BaseLoader):
         ignore_load_errors: bool = False,
         *,
         preserve_order: bool = True,
+        trust_env: bool = False,
     ):
         """Initialize with a webpage path."""
@@ -104,6 +105,8 @@ class AsyncHtmlLoader(BaseLoader):
         self.ignore_load_errors = ignore_load_errors
         self.preserve_order = preserve_order
+        self.trust_env = trust_env
+
     def _fetch_valid_connection_docs(self, url: str) -> Any:
         if self.ignore_load_errors:
             try:
@@ -126,7 +129,7 @@ class AsyncHtmlLoader(BaseLoader):
     async def _fetch(
         self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
     ) -> str:
-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
             for i in range(retries):
                 try:
                     async with session.get(
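For reference, a minimal usage sketch of the new flag (the proxy URL and target page below are placeholders, not part of this commit):

import os

from langchain_community.document_loaders import AsyncHtmlLoader

# trust_env=True makes the underlying aiohttp.ClientSession read proxy
# settings from the HTTP_PROXY/HTTPS_PROXY env variables (or .netrc).
os.environ["HTTPS_PROXY"] = "http://proxy.internal:3128"  # placeholder proxy

loader = AsyncHtmlLoader(["https://example.com"], trust_env=True)
docs = loader.load()  # fetches are now routed through the configured proxy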

libs/community/langchain_community/retrievers/web_research.py

@@ -75,6 +75,11 @@ class WebResearchRetriever(BaseRetriever):
     url_database: List[str] = Field(
         default_factory=list, description="List of processed URLs"
     )
+    trust_env: bool = Field(
+        False,
+        description="Whether to use the http_proxy/https_proxy env variables or "
+        "check .netrc for proxy configuration",
+    )
 
     @classmethod
     def from_llm(
@@ -87,6 +92,7 @@ class WebResearchRetriever(BaseRetriever):
         text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
             chunk_size=1500, chunk_overlap=150
         ),
+        trust_env: bool = False,
     ) -> "WebResearchRetriever":
         """Initialize from llm using default template.
@@ -97,6 +103,8 @@ class WebResearchRetriever(BaseRetriever):
             prompt: prompt to generating search questions
             num_search_results: Number of pages per Google search
             text_splitter: Text splitter for splitting web pages into chunks
+            trust_env: Whether to use the http_proxy/https_proxy env variables
+                or check .netrc for proxy configuration
 
         Returns:
             WebResearchRetriever
@@ -124,6 +132,7 @@ class WebResearchRetriever(BaseRetriever):
             search=search,
             num_search_results=num_search_results,
             text_splitter=text_splitter,
+            trust_env=trust_env,
         )
 
     def clean_search_query(self, query: str) -> str:
@@ -191,7 +200,9 @@ class WebResearchRetriever(BaseRetriever):
         logger.info(f"New URLs to load: {new_urls}")
         # Load, split, and add new urls to vectorstore
         if new_urls:
-            loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True)
+            loader = AsyncHtmlLoader(
+                new_urls, ignore_load_errors=True, trust_env=self.trust_env
+            )
             html2text = Html2TextTransformer()
             logger.info("Indexing new urls...")
             docs = loader.load()
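The same flag is surfaced through WebResearchRetriever.from_llm. A sketch of enabling it end to end; the llm, vectorstore, and search objects here are illustrative and assume OpenAI and Google CSE credentials are already configured:

import os

from langchain_community.retrievers.web_research import WebResearchRetriever
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

os.environ["HTTPS_PROXY"] = "http://proxy.internal:3128"  # placeholder proxy

retriever = WebResearchRetriever.from_llm(
    vectorstore=Chroma(embedding_function=OpenAIEmbeddings()),
    llm=ChatOpenAI(temperature=0),
    search=GoogleSearchAPIWrapper(),  # requires GOOGLE_API_KEY / GOOGLE_CSE_ID
    trust_env=True,  # forwarded to AsyncHtmlLoader -> aiohttp.ClientSession
)

docs = retriever.invoke("What does aiohttp's trust_env flag do?")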