Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-13 05:01:44 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
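("Run all files" presumably refers to the standard invocation pre-commit run --all-files, which applies every configured hook to the entire repository rather than only to staged changes.)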
@@ -4,17 +4,16 @@ import torch.nn as nn
 
 
-
 class BaseEmbeddingBag(abc.ABC, nn.Module):
     def __init__(
         self,
         num_embeddings,
         embedding_dim,
         padding_idx=None,
         max_norm=None,
-        norm_type=2.,
+        norm_type=2.0,
         scale_grad_by_freq=False,
         sparse=False,
-        mode='mean',
+        mode="mean",
         include_last_offset=False,
     ):
         super(BaseEmbeddingBag, self).__init__()
@@ -22,9 +21,9 @@ class BaseEmbeddingBag(abc.ABC, nn.Module):
         self.embedding_dim = embedding_dim
         if padding_idx is not None:
             if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings"
             elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+                assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings"
                 padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         self.max_norm = max_norm
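For context, the padding_idx handling touched in the second hunk follows the same convention as torch.nn.EmbeddingBag: a negative index must lie within [-num_embeddings, 0) and is wrapped to its positive equivalent. A minimal standalone sketch of that normalization (the function name normalize_padding_idx is ours, for illustration only):

# Sketch of the padding_idx normalization shown in the second hunk.
# The helper name is hypothetical; the bounds checks mirror the asserts above.
def normalize_padding_idx(padding_idx, num_embeddings):
    if padding_idx is not None:
        if padding_idx > 0:
            assert padding_idx < num_embeddings, "Padding_idx must be within num_embeddings"
        elif padding_idx < 0:
            assert padding_idx >= -num_embeddings, "Padding_idx must be within num_embeddings"
            # Wrap a negative index to its positive equivalent, e.g. -1 -> num_embeddings - 1.
            padding_idx = num_embeddings + padding_idx
    return padding_idx

# With 10 embeddings, -1 refers to the last row (index 9); in-range positive indices pass through.
assert normalize_padding_idx(-1, 10) == 9
assert normalize_padding_idx(3, 10) == 3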