ColossalAI/examples/tutorial/sequence_parallel/model/layers/pooler.py

import torch
import torch.nn as nn

from .linear import Linear


class Pooler(nn.Module):
    """Pooler layer.

    Pool the hidden state of a specific token (for example, the start-of-
    sequence token) and apply a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size of both the input and the output features.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]  # [b, h]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
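

# A minimal usage sketch, kept as comments because the relative `.linear`
# import means this module only runs as part of the package, not as a script.
# Assumptions (not from the original file): this `Linear` accepts a plain
# tensor like `torch.nn.Linear` does, and the batch/sequence/hidden sizes
# below are illustrative only.
#
#   pooler = Pooler(hidden_size=768)
#   hidden_states = torch.randn(4, 128, 768)  # [batch, seq_len, hidden]
#   pooled = pooler(hidden_states)            # pools token 0 -> shape [4, 768]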