mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
Fix Async Shared Resource Bug (#4751)
Use an async queue to distribute tracers rather than inappropriately sharing a single one
This commit is contained in:
parent
3f0357f94a
commit
a128d95aeb
@ -26,7 +26,6 @@ from pydantic import BaseSettings, Field, root_validator
|
||||
from requests import Response
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import tracing_v2_enabled
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
@ -351,14 +350,13 @@ class LangChainPlusClient(BaseSettings):
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
finally:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
async def _gather_with_concurrency(
|
||||
n: int,
|
||||
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]],
|
||||
initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]],
|
||||
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
|
||||
) -> List[Any]:
|
||||
"""
|
||||
@ -373,21 +371,28 @@ class LangChainPlusClient(BaseSettings):
|
||||
A list of results from the coroutines.
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(n)
|
||||
tracer, job_state = await initializer()
|
||||
job_state = {"num_processed": 0}
|
||||
|
||||
tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue()
|
||||
for _ in range(n):
|
||||
tracer_queue.put_nowait(await initializer())
|
||||
|
||||
async def run_coroutine_with_semaphore(
|
||||
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
|
||||
) -> Any:
|
||||
async with semaphore:
|
||||
return await async_func(tracer, job_state)
|
||||
tracer = await tracer_queue.get()
|
||||
try:
|
||||
result = await async_func(tracer, job_state)
|
||||
finally:
|
||||
tracer_queue.put_nowait(tracer)
|
||||
return result
|
||||
|
||||
return await asyncio.gather(
|
||||
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
||||
)
|
||||
|
||||
async def _tracer_initializer(
|
||||
self, session_name: str
|
||||
) -> Tuple[LangChainTracer, dict]:
|
||||
async def _tracer_initializer(self, session_name: str) -> LangChainTracer:
|
||||
"""
|
||||
Initialize a tracer to share across tasks.
|
||||
|
||||
@ -397,11 +402,9 @@ class LangChainPlusClient(BaseSettings):
|
||||
Returns:
|
||||
A LangChainTracer instance with an active session.
|
||||
"""
|
||||
job_state = {"num_processed": 0}
|
||||
with tracing_v2_enabled(session_name=session_name) as session:
|
||||
tracer = LangChainTracer()
|
||||
tracer.session = session
|
||||
return tracer, job_state
|
||||
tracer = LangChainTracer(session_name=session_name)
|
||||
tracer.ensure_session()
|
||||
return tracer
|
||||
|
||||
async def arun_on_dataset(
|
||||
self,
|
||||
@ -513,8 +516,7 @@ class LangChainPlusClient(BaseSettings):
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
||||
outputs.append({"Error": str(e)})
|
||||
finally:
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
langchain_tracer.example_id = previous_example_id
|
||||
return outputs
|
||||
|
||||
def run_on_dataset(
|
||||
@ -550,18 +552,16 @@ class LangChainPlusClient(BaseSettings):
|
||||
dataset = self.read_dataset(dataset_name=dataset_name)
|
||||
examples = list(self.list_examples(dataset_id=str(dataset.id)))
|
||||
results: Dict[str, Any] = {}
|
||||
with tracing_v2_enabled(session_name=session_name) as session:
|
||||
tracer = LangChainTracer()
|
||||
tracer.session = session
|
||||
|
||||
for i, example in enumerate(examples):
|
||||
result = self.run_llm_or_chain(
|
||||
example,
|
||||
tracer,
|
||||
llm_or_chain_factory,
|
||||
num_repetitions,
|
||||
)
|
||||
if verbose:
|
||||
print(f"{i+1} processed", flush=True, end="\r")
|
||||
results[str(example.id)] = result
|
||||
tracer = LangChainTracer(session_name=session_name)
|
||||
tracer.ensure_session()
|
||||
for i, example in enumerate(examples):
|
||||
result = self.run_llm_or_chain(
|
||||
example,
|
||||
tracer,
|
||||
llm_or_chain_factory,
|
||||
num_repetitions,
|
||||
)
|
||||
if verbose:
|
||||
print(f"{i+1} processed", flush=True, end="\r")
|
||||
results[str(example.id)] = result
|
||||
return results
|
||||
|
Loading…
Reference in New Issue
Block a user