Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-05 04:55:14 +00:00
infra: add -p to mkdir in lint steps (#17013)
Previously, if the lint step did not find a mypy cache directory, it would not run; this change makes it always run. It also adds mypy ignore comments for existing, previously uncaught issues in order to unblock other PRs.
---------
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
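The Makefile change itself is not shown on this page. A minimal sketch of the idea follows; the MYPY_CACHE variable, the lint target name, and the poetry invocation are illustrative assumptions, not necessarily the repo's actual Makefile contents:

    # Assumed names for illustration only.
    MYPY_CACHE := .mypy_cache

    lint:
    	# mkdir -p creates the cache directory if it is missing and succeeds
    	# quietly if it already exists, so the mypy step is never skipped just
    	# because no cache was found.
    	mkdir -p $(MYPY_CACHE)
    	poetry run mypy . --cache-dir $(MYPY_CACHE)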
@@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime]
 
-    def format_request_payload(
+    def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
@@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)
 
-    def format_response_payload(
+    def format_response_payload(  # type: ignore[override]
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         try:
             choice = json.loads(output)[0]["0"]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
         return Generation(text=choice)
 
 
@@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime]
 
-    def format_request_payload(
+    def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         ContentFormatterBase.escape_special_characters(prompt)
@@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)
 
-    def format_response_payload(
+    def format_response_payload(  # type: ignore[override]
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         try:
             choice = json.loads(output)[0]["0"]["generated_text"]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
         return Generation(text=choice)
 
 
@@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime]
 
-    def format_request_payload(
+    def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         prompt = ContentFormatterBase.escape_special_characters(prompt)
@@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)
 
-    def format_response_payload(
+    def format_response_payload(  # type: ignore[override]
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         try:
             choice = json.loads(output)[0]
         except (KeyError, IndexError, TypeError) as e:
-            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+            raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
         return Generation(text=choice)
 
 
@@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
 
-    def format_request_payload(
+    def format_request_payload(  # type: ignore[override]
         self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
     ) -> bytes:
         """Formats the request according to the chosen api"""
@@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase):
         )
         return str.encode(request_payload)
 
-    def format_response_payload(
+    def format_response_payload(  # type: ignore[override]
         self, output: bytes, api_type: AzureMLEndpointApiType
     ) -> Generation:
         """Formats response"""
@@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase):
             try:
                 choice = json.loads(output)[0]["0"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
             return Generation(text=choice)
         if api_type == AzureMLEndpointApiType.serverless:
             try:
@@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase):
                     "received."
                 )
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
             return Generation(
                 text=choice["text"].strip(),
                 generation_info=dict(
@@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
     ) -> AzureMLEndpointApiType:
         """Validate that endpoint api type is compatible with the URL format."""
         endpoint_url = values.get("endpoint_url")
-        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
+        if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(  # type: ignore[union-attr]
             "/score"
         ):
             raise ValueError(
@@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel):
                 "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
             )
         if field_value == AzureMLEndpointApiType.serverless and not (
-            endpoint_url.endswith("/v1/completions")
-            or endpoint_url.endswith("/v1/chat/completions")
+            endpoint_url.endswith("/v1/completions")  # type: ignore[union-attr]
+            or endpoint_url.endswith("/v1/chat/completions")  # type: ignore[union-attr]
         ):
             raise ValueError(
                 "Endpoints of type `serverless` should follow the format "
@@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel):
         deployment_name = values.get("deployment_name")
 
         http_client = AzureMLEndpointClient(
-            endpoint_url, endpoint_key.get_secret_value(), deployment_name
+            endpoint_url,  # type: ignore
+            endpoint_key.get_secret_value(),  # type: ignore
+            deployment_name,  # type: ignore
         )
         return http_client
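The edits above are targeted mypy suppressions rather than signature fixes. A minimal sketch of the two error codes being silenced; the class, function, and field names here are made up for illustration and are not the actual langchain code:

    from typing import Dict


    class Base:
        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
            ...


    class Child(Base):
        # The subclass adds an api_type parameter, so mypy reports
        # 'Signature of "format_request_payload" incompatible with supertype "Base"'
        # under error code [override]; the comment acknowledges the known mismatch
        # without changing runtime behaviour.
        def format_request_payload(  # type: ignore[override]
            self, prompt: str, model_kwargs: Dict, api_type: str
        ) -> bytes:
            return prompt.encode()


    def validate_endpoint(values: Dict[str, str]) -> None:
        url = values.get("endpoint_url")
        # dict.get() returns Optional[str], so calling .endswith() directly triggers
        # 'Item "None" of "Optional[str]" has no attribute "endswith"' under error
        # code [union-attr]; the diff silences it instead of narrowing the type.
        if not url.endswith("/score"):  # type: ignore[union-attr]
            raise ValueError("realtime endpoints are expected to end in /score")

Running mypy over a sketch like this should reproduce both diagnostics, which is why the commit opts for per-line suppressions to unblock CI rather than reworking the formatter signatures.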