vendor: upgrade vendor ttrpc

Upgrade vendor ttrpc to fix the issue of fd leak.

Fixes: #2000

    0e0f228 Handle ok status
    8c74fe8 Update to go 1.12x on travis
    17f4d32 Client.Call(): do not return error if no Status is set(gRPC v1.23 and up)
    271238a Fix method full name generation
    694de9d metadata as KeyValue type
    3afb82b Fix error handling with server shutdown
    f3eb35b Refactor close handling for ttrpc clients
    de8faac Add godocs for interceptors
    e409d7d Add example binary for testing the example service
    819653f Add client and server unary interceptors
    04523b9 Rename headers to metadata
    5926a92 Support headers
    911c9cd Improve connection error handling
    96dcf73 Handle EOF to prevent file descriptor leak
    ba15956 Make onclose an option.

Signed-off-by: lifupan <lifupan@gmail.com>
This commit is contained in:
lifupan 2019-09-12 14:25:39 +08:00
parent 39864c37ff
commit da4d89bd9a
10 changed files with 379 additions and 107 deletions

5
Gopkg.lock generated
View File

@ -144,12 +144,11 @@
revision = "7d11b49dc0769f6dbb0d1b19f3d48524d1bad9ad"
[[projects]]
branch = "master"
digest = "1:25435262330720ca0cade25af7ee7fb96d0cb70cc1ea0c0961694681c12a90e6"
digest = "1:c710f7b09759fe5ef0122a55a14afa8fab25a51a56e0826c24499563ef9e0e92"
name = "github.com/containerd/ttrpc"
packages = ["."]
pruneopts = "NUT"
revision = "69144327078caa5a2f1d5eda8bea6110bf16eeb3"
revision = "92c8520ef9f86600c650dd540266a007bf03670f"
[[projects]]
branch = "master"

View File

@ -70,6 +70,10 @@
name = "github.com/gogo/protobuf"
revision = "4cbf7e384e768b4e01799441fdf2a706a5635ae7"
[[override]]
name = "github.com/containerd/ttrpc"
revision = "92c8520ef9f86600c650dd540266a007bf03670f"
[[override]]
branch = "master"
name = "github.com/hashicorp/yamux"

View File

@ -18,7 +18,6 @@ package ttrpc
import (
"bufio"
"context"
"encoding/binary"
"io"
"net"
@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
// returned will be valid and caller should send that along to
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
if err != nil {
return messageHeader{}, nil, err
@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
return mh, p, nil
}
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
}

View File

@ -29,6 +29,7 @@ import (
"github.com/gogo/protobuf/proto"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
@ -36,28 +37,56 @@ import (
// closed.
var ErrClosed = errors.New("ttrpc: closed")
// Client for a ttrpc server
type Client struct {
codec codec
conn net.Conn
channel *channel
calls chan *callRequest
closed chan struct{}
ctx context.Context
closed func()
closeOnce sync.Once
closeFunc func()
done chan struct{}
userCloseFunc func()
errOnce sync.Once
err error
interceptor UnaryClientInterceptor
}
func NewClient(conn net.Conn) *Client {
// ClientOpts configures a client
type ClientOpts func(c *Client)
// WithOnClose sets the close func whenever the client's Close() method is called
func WithOnClose(onClose func()) ClientOpts {
return func(c *Client) {
c.userCloseFunc = onClose
}
}
// WithUnaryClientInterceptor sets the provided client interceptor
func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
return func(c *Client) {
c.interceptor = i
}
}
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
c := &Client{
codec: codec{},
conn: conn,
channel: newChannel(conn),
calls: make(chan *callRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
closeFunc: func() {},
closed: cancel,
ctx: ctx,
userCloseFunc: func() {},
interceptor: defaultClientInterceptor,
}
for _, o := range opts {
o(c)
}
go c.run()
@ -87,11 +116,18 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
cresp = &Response{}
)
if metadata, ok := GetMetadata(ctx); ok {
metadata.setRequest(creq)
}
if dl, ok := ctx.Deadline(); ok {
creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds()
}
if err := c.dispatch(ctx, creq, cresp); err != nil {
info := &UnaryClientInfo{
FullMethod: fullPath(service, method),
}
if err := c.interceptor(ctx, creq, cresp, info, c.dispatch); err != nil {
return err
}
@ -99,11 +135,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int
return err
}
if cresp.Status == nil {
return errors.New("no status provided on response")
}
if cresp.Status != nil && cresp.Status.Code != int32(codes.OK) {
return status.ErrorProto(cresp.Status)
}
return nil
}
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
@ -119,8 +154,8 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
case <-ctx.Done():
return ctx.Err()
case c.calls <- call:
case <-c.done:
return c.err
case <-c.ctx.Done():
return c.error()
}
select {
@ -128,75 +163,100 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
return ctx.Err()
case err := <-errs:
return filterCloseErr(err)
case <-c.done:
return c.err
case <-c.ctx.Done():
return c.error()
}
}
func (c *Client) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
c.closed()
})
return nil
}
// OnClose allows a close func to be called when the server is closed
func (c *Client) OnClose(closer func()) {
c.closeFunc = closer
}
type message struct {
messageHeader
p []byte
err error
}
type receiver struct {
wg *sync.WaitGroup
messages chan *message
err error
}
func (r *receiver) run(ctx context.Context, c *channel) {
defer r.wg.Done()
for {
select {
case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
r.err = filterCloseErr(err)
return
}
}
select {
case r.messages <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-ctx.Done():
r.err = ctx.Err()
return
}
}
}
}
func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
shutdown = make(chan struct{})
shutdownErr error
receiversDone = make(chan struct{})
wg sync.WaitGroup
)
// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
go func() {
defer close(shutdown)
// start one more goroutine to recv messages without blocking.
for {
mh, p, err := c.channel.recv(context.TODO())
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
shutdownErr = err
return
}
}
select {
case incoming <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-c.done:
return
}
}
wg.Wait()
close(receiversDone)
}()
go recv.run(c.ctx, c.channel)
defer c.conn.Close()
defer close(c.done)
defer c.closeFunc()
defer func() {
c.conn.Close()
c.userCloseFunc()
}()
for {
select {
case call := <-calls:
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}
@ -212,41 +272,42 @@ func (c *Client) run() {
call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-shutdown:
if shutdownErr != nil {
shutdownErr = filterCloseErr(shutdownErr)
} else {
shutdownErr = ErrClosed
}
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
c.err = shutdownErr
for _, waiter := range waiters {
waiter.errs <- shutdownErr
case <-receiversDone:
// all the receivers have exited
if recv.err != nil {
c.setError(recv.err)
}
// don't return out, let the close of the context trigger the abort of waiters
c.Close()
return
case <-c.closed:
if c.err == nil {
c.err = ErrClosed
}
// broadcast the shutdown error to the remaining waiters.
for _, waiter := range waiters {
waiter.errs <- c.err
}
case <-c.ctx.Done():
abortWaiters(c.error())
return
}
}
}
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
func (c *Client) error() error {
c.errOnce.Do(func() {
if c.err == nil {
c.err = ErrClosed
}
})
return c.err
}
func (c *Client) setError(err error) {
c.errOnce.Do(func() {
c.err = err
})
}
func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
p, err := c.codec.Marshal(msg)
if err != nil {
return err
}
return c.channel.send(ctx, streamID, mtype, p)
return c.channel.send(streamID, mtype, p)
}
func (c *Client) recv(resp *Response, msg *message) error {
@ -267,24 +328,23 @@ func (c *Client) recv(resp *Response, msg *message) error {
//
// This purposely ignores errors with a wrapped cause.
func filterCloseErr(err error) error {
if err == nil {
switch {
case err == nil:
return nil
}
if err == io.EOF {
case err == io.EOF:
return ErrClosed
}
if strings.Contains(err.Error(), "use of closed network connection") {
case errors.Cause(err) == io.EOF:
return ErrClosed
}
case strings.Contains(err.Error(), "use of closed network connection"):
return ErrClosed
default:
// if we have an epipe on a write, we cast to errclosed
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
return ErrClosed
}
}
}
return err
}

View File

@ -20,8 +20,10 @@ import "github.com/pkg/errors"
type serverConfig struct {
handshaker Handshaker
interceptor UnaryServerInterceptor
}
// ServerOpt for configuring a ttrpc server
type ServerOpt func(*serverConfig) error
// WithServerHandshaker can be passed to NewServer to ensure that the
@ -37,3 +39,14 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
return nil
}
}
// WithUnaryServerInterceptor sets the provided interceptor on the server
func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
return func(c *serverConfig) error {
if c.interceptor != nil {
return errors.New("only one interceptor allowed per server")
}
c.interceptor = i
return nil
}
}

50
vendor/github.com/containerd/ttrpc/interceptor.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ttrpc
import "context"
// UnaryServerInfo provides information about the server request
type UnaryServerInfo struct {
FullMethod string
}
// UnaryClientInfo provides information about the client request
type UnaryClientInfo struct {
FullMethod string
}
// Unmarshaler contains the server request data and allows it to be unmarshaled
// into a concrete type
type Unmarshaler func(interface{}) error
// Invoker invokes the client's request and response from the ttrpc server
type Invoker func(context.Context, *Request, *Response) error
// UnaryServerInterceptor specifies the interceptor function for server request/response
type UnaryServerInterceptor func(context.Context, Unmarshaler, *UnaryServerInfo, Method) (interface{}, error)
// UnaryClientInterceptor specifies the interceptor function for client request/response
type UnaryClientInterceptor func(context.Context, *Request, *Response, *UnaryClientInfo, Invoker) error
func defaultServerInterceptor(ctx context.Context, unmarshal Unmarshaler, info *UnaryServerInfo, method Method) (interface{}, error) {
return method(ctx, unmarshal)
}
func defaultClientInterceptor(ctx context.Context, req *Request, resp *Response, _ *UnaryClientInfo, invoker Invoker) error {
return invoker(ctx, req, resp)
}

107
vendor/github.com/containerd/ttrpc/metadata.go generated vendored Normal file
View File

@ -0,0 +1,107 @@
/*
Copyright The containerd Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ttrpc
import (
"context"
"strings"
)
// MD is the user type for ttrpc metadata
type MD map[string][]string
// Get returns the metadata for a given key when they exist.
// If there is no metadata, a nil slice and false are returned.
func (m MD) Get(key string) ([]string, bool) {
key = strings.ToLower(key)
list, ok := m[key]
if !ok || len(list) == 0 {
return nil, false
}
return list, true
}
// Set sets the provided values for a given key.
// The values will overwrite any existing values.
// If no values provided, a key will be deleted.
func (m MD) Set(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 {
delete(m, key)
return
}
m[key] = values
}
// Append appends additional values to the given key.
func (m MD) Append(key string, values ...string) {
key = strings.ToLower(key)
if len(values) == 0 {
return
}
current, ok := m[key]
if ok {
m.Set(key, append(current, values...)...)
} else {
m.Set(key, values...)
}
}
func (m MD) setRequest(r *Request) {
for k, values := range m {
for _, v := range values {
r.Metadata = append(r.Metadata, &KeyValue{
Key: k,
Value: v,
})
}
}
}
func (m MD) fromRequest(r *Request) {
for _, kv := range r.Metadata {
m[kv.Key] = append(m[kv.Key], kv.Value)
}
}
type metadataKey struct{}
// GetMetadata retrieves metadata from context.Context (previously attached with WithMetadata)
func GetMetadata(ctx context.Context) (MD, bool) {
metadata, ok := ctx.Value(metadataKey{}).(MD)
return metadata, ok
}
// GetMetadataValue gets a specific metadata value by name from context.Context
func GetMetadataValue(ctx context.Context, name string) (string, bool) {
metadata, ok := GetMetadata(ctx)
if !ok {
return "", false
}
if list, ok := metadata.Get(name); ok {
return list[0], true
}
return "", false
}
// WithMetadata attaches metadata map to a context.Context
func WithMetadata(ctx context.Context, md MD) context.Context {
return context.WithValue(ctx, metadataKey{}, md)
}

View File

@ -53,10 +53,13 @@ func NewServer(opts ...ServerOpt) (*Server, error) {
return nil, err
}
}
if config.interceptor == nil {
config.interceptor = defaultServerInterceptor
}
return &Server{
config: config,
services: newServiceSet(),
services: newServiceSet(config.interceptor),
done: make(chan struct{}),
listeners: make(map[net.Listener]struct{}),
connections: make(map[*serverConn]struct{}),
@ -341,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
default: // proceed
}
mh, p, err := ch.recv(ctx)
mh, p, err := ch.recv()
if err != nil {
status, ok := status.FromError(err)
if !ok {
@ -438,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
return
}
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
return
}
@ -449,7 +452,12 @@ func (c *serverConn) run(sctx context.Context) {
// branch. Basically, it means that we are no longer receiving
// requests due to a terminal error.
recvErr = nil // connection is now "closing"
if err != nil && err != io.EOF {
if err == io.EOF || err == io.ErrUnexpectedEOF {
// The client went away and we should stop processing
// requests, so that the client connection is closed
return
}
if err != nil {
logrus.WithError(err).Error("error receiving message")
}
case <-shutdown:
@ -461,6 +469,12 @@ func (c *serverConn) run(sctx context.Context) {
var noopFunc = func() {}
func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
if len(req.Metadata) > 0 {
md := MD{}
md.fromRequest(req)
ctx = WithMetadata(ctx, md)
}
cancel = noopFunc
if req.TimeoutNano == 0 {
return ctx, cancel

View File

@ -38,11 +38,13 @@ type ServiceDesc struct {
type serviceSet struct {
services map[string]ServiceDesc
interceptor UnaryServerInterceptor
}
func newServiceSet() *serviceSet {
func newServiceSet(interceptor UnaryServerInterceptor) *serviceSet {
return &serviceSet{
services: make(map[string]ServiceDesc),
interceptor: interceptor,
}
}
@ -84,7 +86,11 @@ func (s *serviceSet) dispatch(ctx context.Context, serviceName, methodName strin
return nil
}
resp, err := method(ctx, unmarshal)
info := &UnaryServerInfo{
FullMethod: fullPath(serviceName, methodName),
}
resp, err := s.interceptor(ctx, unmarshal, info, method)
if err != nil {
return nil, err
}
@ -146,5 +152,5 @@ func convertCode(err error) codes.Code {
}
func fullPath(service, method string) string {
return "/" + path.Join("/", service, method)
return "/" + path.Join(service, method)
}

View File

@ -27,6 +27,7 @@ type Request struct {
Method string `protobuf:"bytes,2,opt,name=method,proto3"`
Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"`
TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"`
Metadata []*KeyValue `protobuf:"bytes,5,rep,name=metadata,proto3"`
}
func (r *Request) Reset() { *r = Request{} }
@ -41,3 +42,22 @@ type Response struct {
func (r *Response) Reset() { *r = Response{} }
func (r *Response) String() string { return fmt.Sprintf("%+#v", r) }
func (r *Response) ProtoMessage() {}
type StringList struct {
List []string `protobuf:"bytes,1,rep,name=list,proto3"`
}
func (r *StringList) Reset() { *r = StringList{} }
func (r *StringList) String() string { return fmt.Sprintf("%+#v", r) }
func (r *StringList) ProtoMessage() {}
func makeStringList(item ...string) StringList { return StringList{List: item} }
type KeyValue struct {
Key string `protobuf:"bytes,1,opt,name=key,proto3"`
Value string `protobuf:"bytes,2,opt,name=value,proto3"`
}
func (m *KeyValue) Reset() { *m = KeyValue{} }
func (*KeyValue) ProtoMessage() {}
func (m *KeyValue) String() string { return fmt.Sprintf("%+#v", m) }