community[minor]: allow enabling proxy in aiohttp session in AsyncHTML (#19499)

Allow enabling proxy in the aiohttp session used by AsyncHtmlLoader.
Sihan Chen 2024-05-23 02:25:06 +08:00 committed by GitHub
parent 36813d2f00
commit 1f81277b9b
3 changed files with 20 additions and 2 deletions


@@ -37,6 +37,10 @@
   "source": [
    "urls = [\"https://www.espn.com\", \"https://lilianweng.github.io/posts/2023-06-23-agent/\"]\n",
    "loader = AsyncHtmlLoader(urls)\n",
    "# If you need to use a proxy to make web requests, for example via the http_proxy/https_proxy environment variables,\n",
    "# please set trust_env=True explicitly here as follows:\n",
    "# loader = AsyncHtmlLoader(urls, trust_env=True)\n",
    "# Otherwise, loader.load() may get stuck because the aiohttp session does not recognize the proxy by default\n",
    "docs = loader.load()"
   ]
  },
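For quick reference outside the notebook, a minimal usage sketch of the new option; the proxy address below is a placeholder, not part of this change:

# Minimal sketch of the new trust_env flag (the proxy URL is a placeholder).
import os

from langchain_community.document_loaders import AsyncHtmlLoader

os.environ["https_proxy"] = "http://localhost:3128"  # placeholder proxy address

urls = ["https://www.espn.com", "https://lilianweng.github.io/posts/2023-06-23-agent/"]
# With trust_env=True the underlying aiohttp session honors the
# http_proxy/https_proxy environment variables (and .netrc).
loader = AsyncHtmlLoader(urls, trust_env=True)
docs = loader.load()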


@@ -64,6 +64,7 @@ class AsyncHtmlLoader(BaseLoader):
        ignore_load_errors: bool = False,
        *,
        preserve_order: bool = True,
        trust_env: bool = False,
    ):
        """Initialize with a webpage path."""
@@ -104,6 +105,8 @@ class AsyncHtmlLoader(BaseLoader):
        self.ignore_load_errors = ignore_load_errors
        self.preserve_order = preserve_order
        self.trust_env = trust_env

    def _fetch_valid_connection_docs(self, url: str) -> Any:
        if self.ignore_load_errors:
            try:
@@ -126,7 +129,7 @@ class AsyncHtmlLoader(BaseLoader):
    async def _fetch(
        self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
    ) -> str:
        async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
            for i in range(retries):
                try:
                    async with session.get(
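The loader change simply forwards the flag to aiohttp, which already implements proxy pickup behind trust_env. A standalone sketch of that underlying behavior, with a placeholder proxy address:

# aiohttp only: trust_env=True makes the session read http_proxy/https_proxy
# environment variables (and .netrc) when issuing requests.
import asyncio
import os

import aiohttp

os.environ["https_proxy"] = "http://localhost:3128"  # placeholder proxy address


async def fetch(url: str) -> str:
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get(url) as response:
            return await response.text()


html = asyncio.run(fetch("https://www.espn.com"))
print(html[:200])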


@@ -75,6 +75,11 @@ class WebResearchRetriever(BaseRetriever):
    url_database: List[str] = Field(
        default_factory=list, description="List of processed URLs"
    )
    trust_env: bool = Field(
        False,
        description="Whether to use the http_proxy/https_proxy env variables or "
        "check .netrc for proxy configuration",
    )

    @classmethod
    def from_llm(
@@ -87,6 +92,7 @@ class WebResearchRetriever(BaseRetriever):
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=1500, chunk_overlap=150
        ),
        trust_env: bool = False,
    ) -> "WebResearchRetriever":
        """Initialize from llm using default template.
@@ -97,6 +103,8 @@ class WebResearchRetriever(BaseRetriever):
            prompt: prompt for generating search questions
            num_search_results: Number of pages per Google search
            text_splitter: Text splitter for splitting web pages into chunks
            trust_env: Whether to use the http_proxy/https_proxy env variables
                or check .netrc for proxy configuration

        Returns:
            WebResearchRetriever
@@ -124,6 +132,7 @@ class WebResearchRetriever(BaseRetriever):
            search=search,
            num_search_results=num_search_results,
            text_splitter=text_splitter,
            trust_env=trust_env,
        )

    def clean_search_query(self, query: str) -> str:
@@ -191,7 +200,9 @@ class WebResearchRetriever(BaseRetriever):
        logger.info(f"New URLs to load: {new_urls}")
        # Load, split, and add new urls to vectorstore
        if new_urls:
            loader = AsyncHtmlLoader(
                new_urls, ignore_load_errors=True, trust_env=self.trust_env
            )
            html2text = Html2TextTransformer()
            logger.info("Indexing new urls...")
            docs = loader.load()
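To exercise the retriever-level pass-through, a hedged end-to-end sketch follows. The vectorstore, LLM, and search wrapper are illustrative choices, the import paths assume the langchain_community layout at the time of this commit, and the snippet assumes OpenAI and Google CSE credentials (GOOGLE_API_KEY, GOOGLE_CSE_ID) are set; only the trust_env keyword comes from this change:

# Illustrative only: trust_env=True is forwarded to the internal AsyncHtmlLoader,
# so page fetching honors http_proxy/https_proxy (or .netrc).
from langchain_community.retrievers.web_research import WebResearchRetriever
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

vectorstore = Chroma(
    embedding_function=OpenAIEmbeddings(), persist_directory="./chroma_db"
)
llm = ChatOpenAI(temperature=0)
search = GoogleSearchAPIWrapper()

retriever = WebResearchRetriever.from_llm(
    vectorstore=vectorstore,
    llm=llm,
    search=search,
    trust_env=True,
)
docs = retriever.invoke("How do LLM-powered autonomous agents work?")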