mirror of
https://github.com/haiwen/seafile-server.git
synced 2025-06-30 08:51:50 +00:00
* Fix crash when closing a channel concurrently * Add ErrCh to notify the client's goroutines to exit * Add WaitGroup to wait for all goroutines to exit * Call Signal on error --------- Co-authored-by: 杨赫然 <heran.yang@seafile.com>
306 lines
6.4 KiB
Go
306 lines
6.4 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"runtime/debug"
|
|
"time"
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
"github.com/gorilla/websocket"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Timing parameters for a client connection.
const (
	// writeWait bounds how long a single write to the peer (a JSON
	// message or a ping) may take before its write deadline expires.
	writeWait = time.Second

	// pongWait is the longest keepAlive tolerates without a pong before
	// it gives up on the client.
	pongWait = 5 * time.Second

	// pingPeriod is how often keepAlive pings the peer. It must stay
	// below pongWait, otherwise a healthy client could be dropped.
	pingPeriod = time.Second

	// checkTokenPeriod is how often checkTokenExpired scans the client's
	// subscriptions for expired tokens.
	checkTokenPeriod = time.Hour
)
|
|
|
|
type (
	// Message is the envelope exchanged between clients and the server.
	// Type selects how Content is interpreted; Content is kept as raw JSON
	// so it can be decoded once Type is known.
	Message struct {
		Type    string          `json:"type"`
		Content json.RawMessage `json:"content"`
	}

	// SubList is the content of a "subscribe" message.
	SubList struct {
		Repos []Repo `json:"repos"`
	}

	// UnsubList is the content of an "unsubscribe" message.
	UnsubList struct {
		Repos []Repo `json:"repos"`
	}

	// Repo names one repo in a (un)subscribe request together with the JWT
	// presented for it. Only subscribe requests check the token.
	Repo struct {
		RepoID string `json:"id"`
		Token  string `json:"jwt_token"`
	}
)
|
|
|
|
// myClaims holds the JWT claims read from a client token. Exp carries no
// json tag, so encoding/json fills it from the "exp" claim by
// case-insensitive field-name matching.
type myClaims struct {
	Exp      int64
	RepoID   string `json:"repo_id"`
	UserName string `json:"username"`
}

// Valid performs no checks of its own and always returns nil; checkToken
// verifies the repo binding and expiry after parsing.
func (*myClaims) Valid() error { return nil }
|
|
|
|
func (client *Client) Close() {
|
|
client.conn.Close()
|
|
}
|
|
|
|
// RecoverWrapper calls f and recovers from any panic it raises. The panic
// value is logged together with a stack trace, and RecoverWrapper then
// returns normally instead of letting the panic propagate.
func RecoverWrapper(f func()) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("panic: %v\n%s", r, debug.Stack())
	}()

	f()
}
|
|
|
|
// HandleMessages connects to the client to process message.
|
|
func (client *Client) HandleMessages() {
|
|
// Set keep alive.
|
|
client.conn.SetPongHandler(func(string) error {
|
|
client.Alive = time.Now()
|
|
return nil
|
|
})
|
|
|
|
client.ConnCloser.AddRunning(4)
|
|
go RecoverWrapper(client.readMessages)
|
|
go RecoverWrapper(client.writeMessages)
|
|
go RecoverWrapper(client.checkTokenExpired)
|
|
go RecoverWrapper(client.keepAlive)
|
|
client.ConnCloser.Wait()
|
|
client.Close()
|
|
UnregisterClient(client)
|
|
for id := range client.Repos {
|
|
client.unsubscribe(id)
|
|
}
|
|
}
|
|
|
|
func (client *Client) readMessages() {
|
|
conn := client.conn
|
|
defer func() {
|
|
client.ConnCloser.Done()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case <-client.ConnCloser.HasBeenClosed():
|
|
return
|
|
default:
|
|
}
|
|
var msg Message
|
|
err := conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
client.ConnCloser.Signal()
|
|
log.Debugf("failed to read json data from client: %s: %v", client.Addr, err)
|
|
return
|
|
}
|
|
|
|
err = client.handleMessage(&msg)
|
|
if err != nil {
|
|
client.ConnCloser.Signal()
|
|
log.Debugf("%v", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func checkToken(tokenString, repoID string) (string, int64, bool) {
|
|
if len(tokenString) == 0 {
|
|
return "", -1, false
|
|
}
|
|
claims := new(myClaims)
|
|
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(privateKey), nil
|
|
})
|
|
if err != nil {
|
|
return "", -1, false
|
|
}
|
|
|
|
if !token.Valid {
|
|
return "", -1, false
|
|
}
|
|
|
|
now := time.Now()
|
|
if claims.RepoID != repoID || claims.Exp <= now.Unix() {
|
|
return "", -1, false
|
|
}
|
|
|
|
return claims.UserName, claims.Exp, true
|
|
}
|
|
|
|
func (client *Client) handleMessage(msg *Message) error {
|
|
content := msg.Content
|
|
|
|
if msg.Type == "subscribe" {
|
|
var list SubList
|
|
err := json.Unmarshal(content, &list)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, repo := range list.Repos {
|
|
user, exp, valid := checkToken(repo.Token, repo.RepoID)
|
|
if !valid {
|
|
client.notifJWTExpired(repo.RepoID)
|
|
continue
|
|
}
|
|
client.subscribe(repo.RepoID, user, exp)
|
|
}
|
|
} else if msg.Type == "unsubscribe" {
|
|
var list UnsubList
|
|
err := json.Unmarshal(content, &list)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, r := range list.Repos {
|
|
client.unsubscribe(r.RepoID)
|
|
}
|
|
} else {
|
|
err := fmt.Errorf("recv unexpected type of message: %s", msg.Type)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// subscribe subscribes to notifications of repos.
|
|
func (client *Client) subscribe(repoID, user string, exp int64) {
|
|
client.User = user
|
|
|
|
client.ReposMutex.Lock()
|
|
client.Repos[repoID] = exp
|
|
client.ReposMutex.Unlock()
|
|
|
|
subMutex.Lock()
|
|
subscribers, ok := subscriptions[repoID]
|
|
if !ok {
|
|
subscribers = newSubscribers(client)
|
|
subscriptions[repoID] = subscribers
|
|
}
|
|
subMutex.Unlock()
|
|
|
|
subscribers.Mutex.Lock()
|
|
subscribers.Clients[client.ID] = client
|
|
subscribers.Mutex.Unlock()
|
|
}
|
|
|
|
func (client *Client) unsubscribe(repoID string) {
|
|
client.ReposMutex.Lock()
|
|
delete(client.Repos, repoID)
|
|
client.ReposMutex.Unlock()
|
|
|
|
subMutex.Lock()
|
|
subscribers, ok := subscriptions[repoID]
|
|
if !ok {
|
|
subMutex.Unlock()
|
|
return
|
|
}
|
|
subMutex.Unlock()
|
|
|
|
subscribers.Mutex.Lock()
|
|
delete(subscribers.Clients, client.ID)
|
|
subscribers.Mutex.Unlock()
|
|
|
|
}
|
|
|
|
func (client *Client) writeMessages() {
|
|
defer func() {
|
|
client.ConnCloser.Done()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case msg := <-client.WCh:
|
|
client.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
client.connMutex.Lock()
|
|
err := client.conn.WriteJSON(msg)
|
|
client.connMutex.Unlock()
|
|
if err != nil {
|
|
client.ConnCloser.Signal()
|
|
log.Debugf("failed to send notification to client: %v", err)
|
|
return
|
|
}
|
|
m, _ := msg.(*Message)
|
|
log.Debugf("send %s event to client %s(%d): %s", m.Type, client.User, client.ID, string(m.Content))
|
|
case <-client.ConnCloser.HasBeenClosed():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *Client) keepAlive() {
|
|
defer func() {
|
|
client.ConnCloser.Done()
|
|
}()
|
|
|
|
ticker := time.NewTicker(pingPeriod)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
if time.Since(client.Alive) > pongWait {
|
|
client.ConnCloser.Signal()
|
|
log.Debugf("disconnected because no pong was received for more than %v", pongWait)
|
|
return
|
|
}
|
|
client.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
client.connMutex.Lock()
|
|
err := client.conn.WriteMessage(websocket.PingMessage, nil)
|
|
client.connMutex.Unlock()
|
|
if err != nil {
|
|
client.ConnCloser.Signal()
|
|
log.Debugf("failed to send ping message to client: %v", err)
|
|
return
|
|
}
|
|
case <-client.ConnCloser.HasBeenClosed():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *Client) checkTokenExpired() {
|
|
defer func() {
|
|
client.ConnCloser.Done()
|
|
}()
|
|
|
|
ticker := time.NewTicker(checkTokenPeriod)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
// unsubscribe will delete repo from client.Repos, we'd better unsubscribe repos later.
|
|
pendingRepos := make(map[string]struct{})
|
|
now := time.Now()
|
|
client.ReposMutex.Lock()
|
|
for repoID, exp := range client.Repos {
|
|
if exp >= now.Unix() {
|
|
continue
|
|
}
|
|
pendingRepos[repoID] = struct{}{}
|
|
}
|
|
client.ReposMutex.Unlock()
|
|
|
|
for repoID := range pendingRepos {
|
|
client.unsubscribe(repoID)
|
|
client.notifJWTExpired(repoID)
|
|
}
|
|
case <-client.ConnCloser.HasBeenClosed():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (client *Client) notifJWTExpired(repoID string) {
|
|
msg := new(Message)
|
|
msg.Type = "jwt-expired"
|
|
content := fmt.Sprintf("{\"repo_id\":\"%s\"}", repoID)
|
|
msg.Content = []byte(content)
|
|
client.WCh <- msg
|
|
}
|