1
0
mirror of https://github.com/haiwen/seafile-server.git synced 2025-06-30 08:51:50 +00:00
seafile-server/notification-server/client.go
feiniks 54ecfbee42
Fix crash when concurrent close channel (#612)
* Fix crash when concurrent close channel

* Add ErrCh to notify client's goroutine to exit

* Add WaitGroup to wait for all goroutines to exit

* Call Signal on error

---------

Co-authored-by: 杨赫然 <heran.yang@seafile.com>
2023-05-09 18:10:01 +08:00

306 lines
6.4 KiB
Go

package main
import (
"encoding/json"
"fmt"
"runtime/debug"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// Timing parameters for the websocket connection handled in this file.
const (
// writeWait is added to the current time to form the write deadline set
// before each websocket write in writeMessages and keepAlive.
writeWait = 1 * time.Second
// pongWait is the longest time keepAlive tolerates since the last pong
// (client.Alive) before it signals the connection to shut down.
pongWait = 5 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = 1 * time.Second
// checkTokenPeriod is the interval at which checkTokenExpired scans the
// client's subscribed repos for expired tokens.
checkTokenPeriod = 1 * time.Hour
)
// Message is the message communicated between clients and server.
type Message struct {
// Type selects how Content is interpreted. handleMessage accepts
// "subscribe" and "unsubscribe"; notifJWTExpired emits "jwt-expired".
Type string `json:"type"`
// Content is the raw JSON payload; it is decoded later according to Type.
Content json.RawMessage `json:"content"`
}
// SubList is the content of a "subscribe" message: the repos the client
// wants to subscribe to. handleMessage checks each entry's token with
// checkToken before subscribing.
type SubList struct {
Repos []Repo `json:"repos"`
}
// UnsubList is the content of an "unsubscribe" message: the repos the client
// wants to stop receiving notifications for. handleMessage reads only the
// RepoID of each entry.
type UnsubList struct {
Repos []Repo `json:"repos"`
}
// Repo identifies one repo in a subscribe/unsubscribe request.
type Repo struct {
// RepoID is the repo identifier, sent as "id" on the wire.
RepoID string `json:"id"`
// Token is the JWT, sent as "jwt_token", that checkToken verifies
// against RepoID when subscribing.
Token string `json:"jwt_token"`
}
// myClaims holds the JWT claims that checkToken reads from a repo token.
type myClaims struct {
// Exp is compared with time.Now().Unix() in checkToken, i.e. it is treated
// as a Unix timestamp in seconds.
// NOTE(review): there is no json tag here; presumably the field is filled
// from the "exp" claim via case-insensitive JSON field matching — confirm.
Exp int64
// RepoID must equal the repo ID the client asked to subscribe to.
RepoID string `json:"repo_id"`
// UserName is returned by checkToken and stored as the client's user.
UserName string `json:"username"`
}
// Valid always returns nil, so no claim validation happens through this
// method; checkToken performs the expiry and repo ID checks explicitly after
// parsing. The method lets *myClaims be passed as the claims argument of
// jwt.ParseWithClaims.
func (*myClaims) Valid() error {
return nil
}
// Close closes the client's underlying connection. Any error reported by the
// connection is deliberately discarded.
func (client *Client) Close() {
	_ = client.conn.Close()
}
// RecoverWrapper invokes f and, if f panics, logs the panic value together
// with a stack trace instead of letting the panic propagate to the caller.
func RecoverWrapper(f func()) {
	defer func() {
		cause := recover()
		if cause == nil {
			return
		}
		log.Printf("panic: %v\n%s", cause, debug.Stack())
	}()

	f()
}
// HandleMessages connects to the client to process message.
//
// It starts the reader, writer, token-expiry and keep-alive goroutines, blocks
// until all four have exited, and then closes the connection, unregisters the
// client and drops all of its subscriptions.
func (client *Client) HandleMessages() {
	// Set keep alive.
	// NOTE(review): the pong handler runs on the reading goroutine while
	// keepAlive reads client.Alive from another goroutine without
	// synchronization — confirm whether Alive needs a lock or an atomic.
	client.conn.SetPongHandler(func(string) error {
		client.Alive = time.Now()
		return nil
	})

	client.ConnCloser.AddRunning(4)
	go RecoverWrapper(client.readMessages)
	go RecoverWrapper(client.writeMessages)
	go RecoverWrapper(client.checkTokenExpired)
	go RecoverWrapper(client.keepAlive)

	// readMessages blocks in ReadJSON without a read deadline, so a shutdown
	// signalled by another goroutine (pong timeout, write failure) would not
	// finish until the peer happened to send data or the TCP connection died.
	// Closing the connection as soon as shutdown is signalled unblocks the
	// reader. The done channel releases the watcher if every worker exited
	// without signalling.
	done := make(chan struct{})
	go func() {
		select {
		case <-client.ConnCloser.HasBeenClosed():
			client.Close()
		case <-done:
		}
	}()

	client.ConnCloser.Wait()
	close(done)

	// Closing again is harmless; Close ignores the resulting error.
	client.Close()
	UnregisterClient(client)
	// unsubscribe deletes from client.Repos; deleting entries while ranging
	// over a map is permitted in Go.
	for id := range client.Repos {
		client.unsubscribe(id)
	}
}
// readMessages reads JSON messages from the connection and dispatches each one
// to handleMessage. It returns once the connection closer has been signalled,
// or after signalling it itself when a read or a handler fails. Done is called
// on the closer when the function returns.
func (client *Client) readMessages() {
	defer client.ConnCloser.Done()

	ws := client.conn
	for {
		// Non-blocking check: stop before the next read if shutdown was
		// already requested.
		select {
		case <-client.ConnCloser.HasBeenClosed():
			return
		default:
		}

		var incoming Message
		if readErr := ws.ReadJSON(&incoming); readErr != nil {
			client.ConnCloser.Signal()
			log.Debugf("failed to read json data from client: %s: %v", client.Addr, readErr)
			return
		}

		if handleErr := client.handleMessage(&incoming); handleErr != nil {
			client.ConnCloser.Signal()
			log.Debugf("%v", handleErr)
			return
		}
	}
}
func checkToken(tokenString, repoID string) (string, int64, bool) {
if len(tokenString) == 0 {
return "", -1, false
}
claims := new(myClaims)
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
return []byte(privateKey), nil
})
if err != nil {
return "", -1, false
}
if !token.Valid {
return "", -1, false
}
now := time.Now()
if claims.RepoID != repoID || claims.Exp <= now.Unix() {
return "", -1, false
}
return claims.UserName, claims.Exp, true
}
// handleMessage processes one message received from the client.
//
// A "subscribe" message subscribes the client to every listed repo whose token
// passes checkToken; repos with a rejected token get a jwt-expired
// notification instead. An "unsubscribe" message removes the listed
// subscriptions. An error is returned when the content cannot be decoded or
// the message type is unknown.
func (client *Client) handleMessage(msg *Message) error {
	switch msg.Type {
	case "subscribe":
		var req SubList
		if err := json.Unmarshal(msg.Content, &req); err != nil {
			return err
		}
		for _, entry := range req.Repos {
			user, exp, ok := checkToken(entry.Token, entry.RepoID)
			if ok {
				client.subscribe(entry.RepoID, user, exp)
			} else {
				client.notifJWTExpired(entry.RepoID)
			}
		}
		return nil

	case "unsubscribe":
		var req UnsubList
		if err := json.Unmarshal(msg.Content, &req); err != nil {
			return err
		}
		for _, entry := range req.Repos {
			client.unsubscribe(entry.RepoID)
		}
		return nil

	default:
		return fmt.Errorf("recv unexpected type of message: %s", msg.Type)
	}
}
// subscribe subscribes to notifications of repos.
//
// It records user as the client's user, stores exp as the expiry of the
// client's subscription to repoID, and adds the client to the repo's
// subscriber set, creating the set if this is the first subscriber.
func (client *Client) subscribe(repoID, user string, exp int64) {
	client.User = user

	// Remember the expiry so checkTokenExpired can drop the subscription.
	client.ReposMutex.Lock()
	client.Repos[repoID] = exp
	client.ReposMutex.Unlock()

	// Find the subscriber set for this repo, creating it on first use.
	subMutex.Lock()
	subs, found := subscriptions[repoID]
	if !found {
		subs = newSubscribers(client)
		subscriptions[repoID] = subs
	}
	subMutex.Unlock()

	subs.Mutex.Lock()
	subs.Clients[client.ID] = client
	subs.Mutex.Unlock()
}
// unsubscribe removes repoID from the client's own repo table and removes the
// client from that repo's subscriber set, if such a set exists.
func (client *Client) unsubscribe(repoID string) {
	client.ReposMutex.Lock()
	delete(client.Repos, repoID)
	client.ReposMutex.Unlock()

	// Only the lookup needs the global lock; the set has its own mutex.
	subMutex.Lock()
	subs, found := subscriptions[repoID]
	subMutex.Unlock()
	if !found {
		return
	}

	subs.Mutex.Lock()
	delete(subs.Clients, client.ID)
	subs.Mutex.Unlock()
}
// writeMessages takes messages from client.WCh and writes each one to the
// connection as JSON. It returns once the connection closer has been
// signalled, or after signalling it itself when a write fails. Done is called
// on the closer when the function returns.
func (client *Client) writeMessages() {
	defer client.ConnCloser.Done()

	for {
		select {
		case msg := <-client.WCh:
			// gorilla/websocket permits only one goroutine at a time in its
			// write methods, and SetWriteDeadline is one of them, so the
			// deadline must be set under the same lock as the write itself
			// (keepAlive writes pings on the same connection).
			client.connMutex.Lock()
			client.conn.SetWriteDeadline(time.Now().Add(writeWait))
			err := client.conn.WriteJSON(msg)
			client.connMutex.Unlock()
			if err != nil {
				client.ConnCloser.Signal()
				log.Debugf("failed to send notification to client: %v", err)
				return
			}

			// Only *Message values have the fields logged below; skip the log
			// for anything else rather than dereferencing a nil pointer.
			if m, ok := msg.(*Message); ok && m != nil {
				log.Debugf("send %s event to client %s(%d): %s", m.Type, client.User, client.ID, string(m.Content))
			}
		case <-client.ConnCloser.HasBeenClosed():
			return
		}
	}
}
// keepAlive sends a websocket ping every pingPeriod and signals the connection
// closer when no pong has been seen for more than pongWait or when a ping
// cannot be written. It also returns when the closer is signalled by another
// goroutine. Done is called on the closer when the function returns.
func (client *Client) keepAlive() {
	defer client.ConnCloser.Done()

	ticker := time.NewTicker(pingPeriod)
	// Release the ticker when the connection goes away; otherwise every
	// connection leaves a running ticker behind.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if time.Since(client.Alive) > pongWait {
				client.ConnCloser.Signal()
				log.Debugf("disconnected because no pong was received for more than %v", pongWait)
				return
			}

			// SetWriteDeadline is a write method in gorilla/websocket and must
			// not run concurrently with writeMessages, so it is called under
			// connMutex together with the ping write.
			client.connMutex.Lock()
			client.conn.SetWriteDeadline(time.Now().Add(writeWait))
			err := client.conn.WriteMessage(websocket.PingMessage, nil)
			client.connMutex.Unlock()
			if err != nil {
				client.ConnCloser.Signal()
				log.Debugf("failed to send ping message to client: %v", err)
				return
			}
		case <-client.ConnCloser.HasBeenClosed():
			return
		}
	}
}
// checkTokenExpired wakes up every checkTokenPeriod, unsubscribes the client
// from every repo whose token has expired and sends a jwt-expired notification
// for each of them. It returns when the connection closer is signalled. Done
// is called on the closer when the function returns.
func (client *Client) checkTokenExpired() {
	defer client.ConnCloser.Done()

	ticker := time.NewTicker(checkTokenPeriod)
	// Release the ticker when the connection goes away; otherwise every
	// connection leaves a running ticker behind.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// unsubscribe takes ReposMutex and deletes from client.Repos, so
			// collect the expired repos first and unsubscribe after the lock
			// has been released.
			var expired []string
			now := time.Now().Unix()

			client.ReposMutex.Lock()
			for repoID, exp := range client.Repos {
				// Same boundary as checkToken: a token whose expiry equals
				// the current second is already expired.
				if exp <= now {
					expired = append(expired, repoID)
				}
			}
			client.ReposMutex.Unlock()

			for _, repoID := range expired {
				client.unsubscribe(repoID)
				client.notifJWTExpired(repoID)
			}
		case <-client.ConnCloser.HasBeenClosed():
			return
		}
	}
}
// notifJWTExpired queues a "jwt-expired" message for repoID on client.WCh so
// that writeMessages delivers it to the peer. The notification is dropped if
// the connection is shutting down.
func (client *Client) notifJWTExpired(repoID string) {
	// repoID can come straight from a client request, so let encoding/json
	// escape it instead of splicing it into a JSON string by hand.
	content, err := json.Marshal(struct {
		RepoID string `json:"repo_id"`
	}{RepoID: repoID})
	if err != nil {
		log.Debugf("failed to encode jwt-expired notification: %v", err)
		return
	}

	msg := &Message{Type: "jwt-expired", Content: content}

	// Do not block forever when the writer goroutine has already exited:
	// a stuck sender would never call Done and HandleMessages would hang in
	// Wait.
	select {
	case client.WCh <- msg:
	case <-client.ConnCloser.HasBeenClosed():
	}
}