infra: add -p to mkdir in lint steps (#17013)

Previously, if this did not find a mypy cache then it wouldnt run

this makes it always run

adding mypy ignore comments with existing uncaught issues to unblock other prs

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Harrison Chase
2024-02-05 11:22:06 -08:00
committed by GitHub
parent db6af21395
commit 4eda647fdd
103 changed files with 378 additions and 369 deletions

View File

@@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
@@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
ContentFormatterBase.escape_special_characters(prompt)
@@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]["0"]["generated_text"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
@@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
"""Formats the request according to the chosen api"""
@@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
"""Formats response"""
@@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase):
try:
choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
if api_type == AzureMLEndpointApiType.serverless:
try:
@@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase):
"received."
)
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(
text=choice["text"].strip(),
generation_info=dict(
@@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
) -> AzureMLEndpointApiType:
"""Validate that endpoint api type is compatible with the URL format."""
endpoint_url = values.get("endpoint_url")
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr]
"/score"
):
raise ValueError(
@@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel):
"`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
)
if field_value == AzureMLEndpointApiType.serverless and not (
endpoint_url.endswith("/v1/completions")
or endpoint_url.endswith("/v1/chat/completions")
endpoint_url.endswith("/v1/completions") # type: ignore[union-attr]
or endpoint_url.endswith("/v1/chat/completions") # type: ignore[union-attr]
):
raise ValueError(
"Endpoints of type `serverless` should follow the format "
@@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel):
deployment_name = values.get("deployment_name")
http_client = AzureMLEndpointClient(
endpoint_url, endpoint_key.get_secret_value(), deployment_name
endpoint_url, # type: ignore
endpoint_key.get_secret_value(), # type: ignore
deployment_name, # type: ignore
)
return http_client