mirror of
				https://github.com/k3s-io/kubernetes.git
				synced 2025-11-03 23:40:03 +00:00 
			
		
		
		
	https://github.com/containerd/ttrpc/compare/v1.2.1...v1.2.2 Signed-off-by: Humble Chirammal <humble.devassy@gmail.com>
		
			
				
	
	
		
			513 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			513 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
/*
 | 
						|
   Copyright The containerd Authors.
 | 
						|
 | 
						|
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
   you may not use this file except in compliance with the License.
 | 
						|
   You may obtain a copy of the License at
 | 
						|
 | 
						|
       http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
   Unless required by applicable law or agreed to in writing, software
 | 
						|
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
   See the License for the specific language governing permissions and
 | 
						|
   limitations under the License.
 | 
						|
*/
 | 
						|
 | 
						|
package ttrpc
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
	"syscall"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/sirupsen/logrus"
 | 
						|
	"google.golang.org/grpc/codes"
 | 
						|
	"google.golang.org/grpc/status"
 | 
						|
	"google.golang.org/protobuf/proto"
 | 
						|
)
 | 
						|
 | 
						|
// Client for a ttrpc server
 | 
						|
type Client struct {
 | 
						|
	codec   codec
 | 
						|
	conn    net.Conn
 | 
						|
	channel *channel
 | 
						|
 | 
						|
	streamLock   sync.RWMutex
 | 
						|
	streams      map[streamID]*stream
 | 
						|
	nextStreamID streamID
 | 
						|
	sendLock     sync.Mutex
 | 
						|
 | 
						|
	ctx    context.Context
 | 
						|
	closed func()
 | 
						|
 | 
						|
	closeOnce       sync.Once
 | 
						|
	userCloseFunc   func()
 | 
						|
	userCloseWaitCh chan struct{}
 | 
						|
 | 
						|
	interceptor UnaryClientInterceptor
 | 
						|
}
 | 
						|
 | 
						|
// ClientOpts configures a client
 | 
						|
type ClientOpts func(c *Client)
 | 
						|
 | 
						|
// WithOnClose sets the close func whenever the client's Close() method is called
 | 
						|
func WithOnClose(onClose func()) ClientOpts {
 | 
						|
	return func(c *Client) {
 | 
						|
		c.userCloseFunc = onClose
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// WithUnaryClientInterceptor sets the provided client interceptor
 | 
						|
func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
 | 
						|
	return func(c *Client) {
 | 
						|
		c.interceptor = i
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// NewClient creates a new ttrpc client using the given connection
 | 
						|
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
 | 
						|
	ctx, cancel := context.WithCancel(context.Background())
 | 
						|
	channel := newChannel(conn)
 | 
						|
	c := &Client{
 | 
						|
		codec:           codec{},
 | 
						|
		conn:            conn,
 | 
						|
		channel:         channel,
 | 
						|
		streams:         make(map[streamID]*stream),
 | 
						|
		nextStreamID:    1,
 | 
						|
		closed:          cancel,
 | 
						|
		ctx:             ctx,
 | 
						|
		userCloseFunc:   func() {},
 | 
						|
		userCloseWaitCh: make(chan struct{}),
 | 
						|
		interceptor:     defaultClientInterceptor,
 | 
						|
	}
 | 
						|
 | 
						|
	for _, o := range opts {
 | 
						|
		o(c)
 | 
						|
	}
 | 
						|
 | 
						|
	go c.run()
 | 
						|
	return c
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) send(sid uint32, mt messageType, flags uint8, b []byte) error {
 | 
						|
	c.sendLock.Lock()
 | 
						|
	defer c.sendLock.Unlock()
 | 
						|
	return c.channel.send(sid, mt, flags, b)
 | 
						|
}
 | 
						|
 | 
						|
// Call makes a unary request and returns with response
 | 
						|
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
 | 
						|
	payload, err := c.codec.Marshal(req)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	var (
 | 
						|
		creq = &Request{
 | 
						|
			Service: service,
 | 
						|
			Method:  method,
 | 
						|
			Payload: payload,
 | 
						|
			// TODO: metadata from context
 | 
						|
		}
 | 
						|
 | 
						|
		cresp = &Response{}
 | 
						|
	)
 | 
						|
 | 
						|
	if metadata, ok := GetMetadata(ctx); ok {
 | 
						|
		metadata.setRequest(creq)
 | 
						|
	}
 | 
						|
 | 
						|
	if dl, ok := ctx.Deadline(); ok {
 | 
						|
		creq.TimeoutNano = time.Until(dl).Nanoseconds()
 | 
						|
	}
 | 
						|
 | 
						|
	info := &UnaryClientInfo{
 | 
						|
		FullMethod: fullPath(service, method),
 | 
						|
	}
 | 
						|
	if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
 | 
						|
		return status.ErrorProto(cresp.Status)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// StreamDesc describes the shape of a stream: whether the client sends a
// stream of messages, whether the server does, or both.
type StreamDesc struct {
	// StreamingClient reports that the client may send multiple messages,
	// so SendMsg and CloseSend are permitted.
	StreamingClient bool

	// StreamingServer reports that the server may reply with data
	// messages rather than only a single final response.
	StreamingServer bool
}
 | 
						|
 | 
						|
// ClientStream is the client-side handle to an open stream. It is used to
// send messages to, and receive messages from, the server.
type ClientStream interface {
	// CloseSend signals that the client will send no more messages.
	CloseSend() error

	// SendMsg marshals m and sends it on the stream.
	SendMsg(m interface{}) error

	// RecvMsg waits for the next message and unmarshals it into m; it
	// returns io.EOF once the server has finished the stream.
	RecvMsg(m interface{}) error
}
 | 
						|
 | 
						|
type clientStream struct {
 | 
						|
	ctx          context.Context
 | 
						|
	s            *stream
 | 
						|
	c            *Client
 | 
						|
	desc         *StreamDesc
 | 
						|
	localClosed  bool
 | 
						|
	remoteClosed bool
 | 
						|
}
 | 
						|
 | 
						|
func (cs *clientStream) CloseSend() error {
 | 
						|
	if !cs.desc.StreamingClient {
 | 
						|
		return fmt.Errorf("%w: cannot close non-streaming client", ErrProtocol)
 | 
						|
	}
 | 
						|
	if cs.localClosed {
 | 
						|
		return ErrStreamClosed
 | 
						|
	}
 | 
						|
	err := cs.s.send(messageTypeData, flagRemoteClosed|flagNoData, nil)
 | 
						|
	if err != nil {
 | 
						|
		return filterCloseErr(err)
 | 
						|
	}
 | 
						|
	cs.localClosed = true
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (cs *clientStream) SendMsg(m interface{}) error {
 | 
						|
	if !cs.desc.StreamingClient {
 | 
						|
		return fmt.Errorf("%w: cannot send data from non-streaming client", ErrProtocol)
 | 
						|
	}
 | 
						|
	if cs.localClosed {
 | 
						|
		return ErrStreamClosed
 | 
						|
	}
 | 
						|
 | 
						|
	var (
 | 
						|
		payload []byte
 | 
						|
		err     error
 | 
						|
	)
 | 
						|
	if m != nil {
 | 
						|
		payload, err = cs.c.codec.Marshal(m)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	err = cs.s.send(messageTypeData, 0, payload)
 | 
						|
	if err != nil {
 | 
						|
		return filterCloseErr(err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (cs *clientStream) RecvMsg(m interface{}) error {
 | 
						|
	if cs.remoteClosed {
 | 
						|
		return io.EOF
 | 
						|
	}
 | 
						|
 | 
						|
	var msg *streamMessage
 | 
						|
	select {
 | 
						|
	case <-cs.ctx.Done():
 | 
						|
		return cs.ctx.Err()
 | 
						|
	case <-cs.s.recvClose:
 | 
						|
		// If recv has a pending message, process that first
 | 
						|
		select {
 | 
						|
		case msg = <-cs.s.recv:
 | 
						|
		default:
 | 
						|
			return cs.s.recvErr
 | 
						|
		}
 | 
						|
	case msg = <-cs.s.recv:
 | 
						|
	}
 | 
						|
 | 
						|
	if msg.header.Type == messageTypeResponse {
 | 
						|
		resp := &Response{}
 | 
						|
		err := proto.Unmarshal(msg.payload[:msg.header.Length], resp)
 | 
						|
		// return the payload buffer for reuse
 | 
						|
		cs.c.channel.putmbuf(msg.payload)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		if err := cs.c.codec.Unmarshal(resp.Payload, m); err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		if resp.Status != nil && resp.Status.Code != int32(codes.OK) {
 | 
						|
			return status.ErrorProto(resp.Status)
 | 
						|
		}
 | 
						|
 | 
						|
		cs.c.deleteStream(cs.s)
 | 
						|
		cs.remoteClosed = true
 | 
						|
 | 
						|
		return nil
 | 
						|
	} else if msg.header.Type == messageTypeData {
 | 
						|
		if !cs.desc.StreamingServer {
 | 
						|
			cs.c.deleteStream(cs.s)
 | 
						|
			cs.remoteClosed = true
 | 
						|
			return fmt.Errorf("received data from non-streaming server: %w", ErrProtocol)
 | 
						|
		}
 | 
						|
		if msg.header.Flags&flagRemoteClosed == flagRemoteClosed {
 | 
						|
			cs.c.deleteStream(cs.s)
 | 
						|
			cs.remoteClosed = true
 | 
						|
 | 
						|
			if msg.header.Flags&flagNoData == flagNoData {
 | 
						|
				return io.EOF
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		err := cs.c.codec.Unmarshal(msg.payload[:msg.header.Length], m)
 | 
						|
		cs.c.channel.putmbuf(msg.payload)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	return fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
 | 
						|
}
 | 
						|
 | 
						|
// Close closes the ttrpc connection and underlying connection
 | 
						|
func (c *Client) Close() error {
 | 
						|
	c.closeOnce.Do(func() {
 | 
						|
		c.closed()
 | 
						|
 | 
						|
		c.conn.Close()
 | 
						|
	})
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// UserOnCloseWait is used to blocks untils the user's on-close callback
 | 
						|
// finishes.
 | 
						|
func (c *Client) UserOnCloseWait(ctx context.Context) error {
 | 
						|
	select {
 | 
						|
	case <-c.userCloseWaitCh:
 | 
						|
		return nil
 | 
						|
	case <-ctx.Done():
 | 
						|
		return ctx.Err()
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) run() {
 | 
						|
	err := c.receiveLoop()
 | 
						|
	c.Close()
 | 
						|
	c.cleanupStreams(err)
 | 
						|
 | 
						|
	c.userCloseFunc()
 | 
						|
	close(c.userCloseWaitCh)
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) receiveLoop() error {
 | 
						|
	for {
 | 
						|
		select {
 | 
						|
		case <-c.ctx.Done():
 | 
						|
			return ErrClosed
 | 
						|
		default:
 | 
						|
			var (
 | 
						|
				msg = &streamMessage{}
 | 
						|
				err error
 | 
						|
			)
 | 
						|
 | 
						|
			msg.header, msg.payload, err = c.channel.recv()
 | 
						|
			if err != nil {
 | 
						|
				_, ok := status.FromError(err)
 | 
						|
				if !ok {
 | 
						|
					// treat all errors that are not an rpc status as terminal.
 | 
						|
					// all others poison the connection.
 | 
						|
					return filterCloseErr(err)
 | 
						|
				}
 | 
						|
			}
 | 
						|
			sid := streamID(msg.header.StreamID)
 | 
						|
			s := c.getStream(sid)
 | 
						|
			if s == nil {
 | 
						|
				logrus.WithField("stream", sid).Errorf("ttrpc: received message on inactive stream")
 | 
						|
				continue
 | 
						|
			}
 | 
						|
 | 
						|
			if err != nil {
 | 
						|
				s.closeWithError(err)
 | 
						|
			} else {
 | 
						|
				if err := s.receive(c.ctx, msg); err != nil {
 | 
						|
					logrus.WithError(err).WithField("stream", sid).Errorf("ttrpc: failed to handle message")
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// createStream creates a new stream and registers it with the client
 | 
						|
// Introduce stream types for multiple or single response
 | 
						|
func (c *Client) createStream(flags uint8, b []byte) (*stream, error) {
 | 
						|
	c.streamLock.Lock()
 | 
						|
 | 
						|
	// Check if closed since lock acquired to prevent adding
 | 
						|
	// anything after cleanup completes
 | 
						|
	select {
 | 
						|
	case <-c.ctx.Done():
 | 
						|
		c.streamLock.Unlock()
 | 
						|
		return nil, ErrClosed
 | 
						|
	default:
 | 
						|
	}
 | 
						|
 | 
						|
	// Stream ID should be allocated at same time
 | 
						|
	s := newStream(c.nextStreamID, c)
 | 
						|
	c.streams[s.id] = s
 | 
						|
	c.nextStreamID = c.nextStreamID + 2
 | 
						|
 | 
						|
	c.sendLock.Lock()
 | 
						|
	defer c.sendLock.Unlock()
 | 
						|
	c.streamLock.Unlock()
 | 
						|
 | 
						|
	if err := c.channel.send(uint32(s.id), messageTypeRequest, flags, b); err != nil {
 | 
						|
		return s, filterCloseErr(err)
 | 
						|
	}
 | 
						|
 | 
						|
	return s, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) deleteStream(s *stream) {
 | 
						|
	c.streamLock.Lock()
 | 
						|
	delete(c.streams, s.id)
 | 
						|
	c.streamLock.Unlock()
 | 
						|
	s.closeWithError(nil)
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) getStream(sid streamID) *stream {
 | 
						|
	c.streamLock.RLock()
 | 
						|
	s := c.streams[sid]
 | 
						|
	c.streamLock.RUnlock()
 | 
						|
	return s
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) cleanupStreams(err error) {
 | 
						|
	c.streamLock.Lock()
 | 
						|
	defer c.streamLock.Unlock()
 | 
						|
 | 
						|
	for sid, s := range c.streams {
 | 
						|
		s.closeWithError(err)
 | 
						|
		delete(c.streams, sid)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// filterCloseErr rewrites EOF and EPIPE errors to ErrClosed. Use when
 | 
						|
// returning from call or handling errors from main read loop.
 | 
						|
//
 | 
						|
// This purposely ignores errors with a wrapped cause.
 | 
						|
func filterCloseErr(err error) error {
 | 
						|
	switch {
 | 
						|
	case err == nil:
 | 
						|
		return nil
 | 
						|
	case err == io.EOF:
 | 
						|
		return ErrClosed
 | 
						|
	case errors.Is(err, io.ErrClosedPipe):
 | 
						|
		return ErrClosed
 | 
						|
	case errors.Is(err, io.EOF):
 | 
						|
		return ErrClosed
 | 
						|
	case strings.Contains(err.Error(), "use of closed network connection"):
 | 
						|
		return ErrClosed
 | 
						|
	default:
 | 
						|
		// if we have an epipe on a write or econnreset on a read , we cast to errclosed
 | 
						|
		var oerr *net.OpError
 | 
						|
		if errors.As(err, &oerr) {
 | 
						|
			if (oerr.Op == "write" && errors.Is(err, syscall.EPIPE)) ||
 | 
						|
				(oerr.Op == "read" && errors.Is(err, syscall.ECONNRESET)) {
 | 
						|
				return ErrClosed
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
// NewStream creates a new stream with the given stream descriptor to the
 | 
						|
// specified service and method. If not a streaming client, the request object
 | 
						|
// may be provided.
 | 
						|
func (c *Client) NewStream(ctx context.Context, desc *StreamDesc, service, method string, req interface{}) (ClientStream, error) {
 | 
						|
	var payload []byte
 | 
						|
	if req != nil {
 | 
						|
		var err error
 | 
						|
		payload, err = c.codec.Marshal(req)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	request := &Request{
 | 
						|
		Service: service,
 | 
						|
		Method:  method,
 | 
						|
		Payload: payload,
 | 
						|
		// TODO: metadata from context
 | 
						|
	}
 | 
						|
	p, err := c.codec.Marshal(request)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	var flags uint8
 | 
						|
	if desc.StreamingClient {
 | 
						|
		flags = flagRemoteOpen
 | 
						|
	} else {
 | 
						|
		flags = flagRemoteClosed
 | 
						|
	}
 | 
						|
	s, err := c.createStream(flags, p)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &clientStream{
 | 
						|
		ctx:  ctx,
 | 
						|
		s:    s,
 | 
						|
		c:    c,
 | 
						|
		desc: desc,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
 | 
						|
	p, err := c.codec.Marshal(req)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	s, err := c.createStream(0, p)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	defer c.deleteStream(s)
 | 
						|
 | 
						|
	var msg *streamMessage
 | 
						|
	select {
 | 
						|
	case <-ctx.Done():
 | 
						|
		return ctx.Err()
 | 
						|
	case <-c.ctx.Done():
 | 
						|
		return ErrClosed
 | 
						|
	case <-s.recvClose:
 | 
						|
		// If recv has a pending message, process that first
 | 
						|
		select {
 | 
						|
		case msg = <-s.recv:
 | 
						|
		default:
 | 
						|
			return s.recvErr
 | 
						|
		}
 | 
						|
	case msg = <-s.recv:
 | 
						|
	}
 | 
						|
 | 
						|
	if msg.header.Type == messageTypeResponse {
 | 
						|
		err = proto.Unmarshal(msg.payload[:msg.header.Length], resp)
 | 
						|
	} else {
 | 
						|
		err = fmt.Errorf("unexpected %q message received: %w", msg.header.Type, ErrProtocol)
 | 
						|
	}
 | 
						|
 | 
						|
	// return the payload buffer for reuse
 | 
						|
	c.channel.putmbuf(msg.payload)
 | 
						|
 | 
						|
	return err
 | 
						|
}
 |