mirror of
				https://github.com/k3s-io/kubernetes.git
				synced 2025-10-30 21:30:16 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			733 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			733 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // DNS server implementation.
 | |
| 
 | |
| package dns
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/tls"
 | |
| 	"encoding/binary"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// maxTCPQueries bounds how many queries serve answers on a single TCP
// connection before it closes the socket.
const (
	maxTCPQueries = 128
)
 | |
| 
 | |
| // Handler is implemented by any value that implements ServeDNS.
 | |
| type Handler interface {
 | |
| 	ServeDNS(w ResponseWriter, r *Msg)
 | |
| }
 | |
| 
 | |
| // A ResponseWriter interface is used by an DNS handler to
 | |
| // construct an DNS response.
 | |
| type ResponseWriter interface {
 | |
| 	// LocalAddr returns the net.Addr of the server
 | |
| 	LocalAddr() net.Addr
 | |
| 	// RemoteAddr returns the net.Addr of the client that sent the current request.
 | |
| 	RemoteAddr() net.Addr
 | |
| 	// WriteMsg writes a reply back to the client.
 | |
| 	WriteMsg(*Msg) error
 | |
| 	// Write writes a raw buffer back to the client.
 | |
| 	Write([]byte) (int, error)
 | |
| 	// Close closes the connection.
 | |
| 	Close() error
 | |
| 	// TsigStatus returns the status of the Tsig.
 | |
| 	TsigStatus() error
 | |
| 	// TsigTimersOnly sets the tsig timers only boolean.
 | |
| 	TsigTimersOnly(bool)
 | |
| 	// Hijack lets the caller take over the connection.
 | |
| 	// After a call to Hijack(), the DNS package will not do anything with the connection.
 | |
| 	Hijack()
 | |
| }
 | |
| 
 | |
| type response struct {
 | |
| 	hijacked       bool // connection has been hijacked by handler
 | |
| 	tsigStatus     error
 | |
| 	tsigTimersOnly bool
 | |
| 	tsigRequestMAC string
 | |
| 	tsigSecret     map[string]string // the tsig secrets
 | |
| 	udp            *net.UDPConn      // i/o connection if UDP was used
 | |
| 	tcp            net.Conn          // i/o connection if TCP was used
 | |
| 	udpSession     *SessionUDP       // oob data to get egress interface right
 | |
| 	remoteAddr     net.Addr          // address of the client
 | |
| 	writer         Writer            // writer to output the raw DNS bits
 | |
| }
 | |
| 
 | |
| // ServeMux is an DNS request multiplexer. It matches the
 | |
| // zone name of each incoming request against a list of
 | |
| // registered patterns add calls the handler for the pattern
 | |
| // that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
 | |
| // that queries for the DS record are redirected to the parent zone (if that
 | |
| // is also registered), otherwise the child gets the query.
 | |
| // ServeMux is also safe for concurrent access from multiple goroutines.
 | |
| type ServeMux struct {
 | |
| 	z map[string]Handler
 | |
| 	m *sync.RWMutex
 | |
| }
 | |
| 
 | |
| // NewServeMux allocates and returns a new ServeMux.
 | |
| func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
 | |
| 
 | |
| // DefaultServeMux is the default ServeMux used by Serve.
 | |
| var DefaultServeMux = NewServeMux()
 | |
| 
 | |
| // The HandlerFunc type is an adapter to allow the use of
 | |
| // ordinary functions as DNS handlers.  If f is a function
 | |
| // with the appropriate signature, HandlerFunc(f) is a
 | |
| // Handler object that calls f.
 | |
| type HandlerFunc func(ResponseWriter, *Msg)
 | |
| 
 | |
| // ServeDNS calls f(w, r).
 | |
| func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
 | |
| 	f(w, r)
 | |
| }
 | |
| 
 | |
| // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
 | |
| func HandleFailed(w ResponseWriter, r *Msg) {
 | |
| 	m := new(Msg)
 | |
| 	m.SetRcode(r, RcodeServerFailure)
 | |
| 	// does not matter if this write fails
 | |
| 	w.WriteMsg(m)
 | |
| }
 | |
| 
 | |
| func failedHandler() Handler { return HandlerFunc(HandleFailed) }
 | |
| 
 | |
| // ListenAndServe Starts a server on address and network specified Invoke handler
 | |
| // for incoming queries.
 | |
| func ListenAndServe(addr string, network string, handler Handler) error {
 | |
| 	server := &Server{Addr: addr, Net: network, Handler: handler}
 | |
| 	return server.ListenAndServe()
 | |
| }
 | |
| 
 | |
| // ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
 | |
| // http://golang.org/pkg/net/http/#ListenAndServeTLS
 | |
| func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
 | |
| 	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	config := tls.Config{
 | |
| 		Certificates: []tls.Certificate{cert},
 | |
| 	}
 | |
| 
 | |
| 	server := &Server{
 | |
| 		Addr:      addr,
 | |
| 		Net:       "tcp-tls",
 | |
| 		TLSConfig: &config,
 | |
| 		Handler:   handler,
 | |
| 	}
 | |
| 
 | |
| 	return server.ListenAndServe()
 | |
| }
 | |
| 
 | |
| // ActivateAndServe activates a server with a listener from systemd,
 | |
| // l and p should not both be non-nil.
 | |
| // If both l and p are not nil only p will be used.
 | |
| // Invoke handler for incoming queries.
 | |
| func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
 | |
| 	server := &Server{Listener: l, PacketConn: p, Handler: handler}
 | |
| 	return server.ActivateAndServe()
 | |
| }
 | |
| 
 | |
| func (mux *ServeMux) match(q string, t uint16) Handler {
 | |
| 	mux.m.RLock()
 | |
| 	defer mux.m.RUnlock()
 | |
| 	var handler Handler
 | |
| 	b := make([]byte, len(q)) // worst case, one label of length q
 | |
| 	off := 0
 | |
| 	end := false
 | |
| 	for {
 | |
| 		l := len(q[off:])
 | |
| 		for i := 0; i < l; i++ {
 | |
| 			b[i] = q[off+i]
 | |
| 			if b[i] >= 'A' && b[i] <= 'Z' {
 | |
| 				b[i] |= ('a' - 'A')
 | |
| 			}
 | |
| 		}
 | |
| 		if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key
 | |
| 			if t != TypeDS {
 | |
| 				return h
 | |
| 			}
 | |
| 			// Continue for DS to see if we have a parent too, if so delegeate to the parent
 | |
| 			handler = h
 | |
| 		}
 | |
| 		off, end = NextLabel(q, off)
 | |
| 		if end {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	// Wildcard match, if we have found nothing try the root zone as a last resort.
 | |
| 	if h, ok := mux.z["."]; ok {
 | |
| 		return h
 | |
| 	}
 | |
| 	return handler
 | |
| }
 | |
| 
 | |
| // Handle adds a handler to the ServeMux for pattern.
 | |
| func (mux *ServeMux) Handle(pattern string, handler Handler) {
 | |
| 	if pattern == "" {
 | |
| 		panic("dns: invalid pattern " + pattern)
 | |
| 	}
 | |
| 	mux.m.Lock()
 | |
| 	mux.z[Fqdn(pattern)] = handler
 | |
| 	mux.m.Unlock()
 | |
| }
 | |
| 
 | |
| // HandleFunc adds a handler function to the ServeMux for pattern.
 | |
| func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
 | |
| 	mux.Handle(pattern, HandlerFunc(handler))
 | |
| }
 | |
| 
 | |
| // HandleRemove deregistrars the handler specific for pattern from the ServeMux.
 | |
| func (mux *ServeMux) HandleRemove(pattern string) {
 | |
| 	if pattern == "" {
 | |
| 		panic("dns: invalid pattern " + pattern)
 | |
| 	}
 | |
| 	mux.m.Lock()
 | |
| 	delete(mux.z, Fqdn(pattern))
 | |
| 	mux.m.Unlock()
 | |
| }
 | |
| 
 | |
| // ServeDNS dispatches the request to the handler whose
 | |
| // pattern most closely matches the request message. If DefaultServeMux
 | |
| // is used the correct thing for DS queries is done: a possible parent
 | |
| // is sought.
 | |
| // If no handler is found a standard SERVFAIL message is returned
 | |
| // If the request message does not have exactly one question in the
 | |
| // question section a SERVFAIL is returned, unlesss Unsafe is true.
 | |
| func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
 | |
| 	var h Handler
 | |
| 	if len(request.Question) < 1 { // allow more than one question
 | |
| 		h = failedHandler()
 | |
| 	} else {
 | |
| 		if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
 | |
| 			h = failedHandler()
 | |
| 		}
 | |
| 	}
 | |
| 	h.ServeDNS(w, request)
 | |
| }
 | |
| 
 | |
| // Handle registers the handler with the given pattern
 | |
| // in the DefaultServeMux. The documentation for
 | |
| // ServeMux explains how patterns are matched.
 | |
| func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
 | |
| 
 | |
| // HandleRemove deregisters the handle with the given pattern
 | |
| // in the DefaultServeMux.
 | |
| func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
 | |
| 
 | |
| // HandleFunc registers the handler function with the given pattern
 | |
| // in the DefaultServeMux.
 | |
| func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
 | |
| 	DefaultServeMux.HandleFunc(pattern, handler)
 | |
| }
 | |
| 
 | |
// Writer writes raw DNS messages. Every call to Write is expected to carry
// one entire message.
type Writer interface {
	io.Writer
}
 | |
| 
 | |
| // Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
 | |
| type Reader interface {
 | |
| 	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
 | |
| 	// connection properties, for example the read-deadline.
 | |
| 	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
 | |
| 	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
 | |
| 	// connection properties, for example the read-deadline.
 | |
| 	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
 | |
| }
 | |
| 
 | |
| // defaultReader is an adapter for the Server struct that implements the Reader interface
 | |
| // using the readTCP and readUDP func of the embedded Server.
 | |
| type defaultReader struct {
 | |
| 	*Server
 | |
| }
 | |
| 
 | |
| func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
 | |
| 	return dr.readTCP(conn, timeout)
 | |
| }
 | |
| 
 | |
| func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
 | |
| 	return dr.readUDP(conn, timeout)
 | |
| }
 | |
| 
 | |
| // DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
 | |
| // Implementations should never return a nil Reader.
 | |
| type DecorateReader func(Reader) Reader
 | |
| 
 | |
| // DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
 | |
| // Implementations should never return a nil Writer.
 | |
| type DecorateWriter func(Writer) Writer
 | |
| 
 | |
| // A Server defines parameters for running an DNS server.
 | |
| type Server struct {
 | |
| 	// Address to listen on, ":dns" if empty.
 | |
| 	Addr string
 | |
| 	// if "tcp" or "tcp-tls" (DNS over TLS) it will invoke a TCP listener, otherwise an UDP one
 | |
| 	Net string
 | |
| 	// TCP Listener to use, this is to aid in systemd's socket activation.
 | |
| 	Listener net.Listener
 | |
| 	// TLS connection configuration
 | |
| 	TLSConfig *tls.Config
 | |
| 	// UDP "Listener" to use, this is to aid in systemd's socket activation.
 | |
| 	PacketConn net.PacketConn
 | |
| 	// Handler to invoke, dns.DefaultServeMux if nil.
 | |
| 	Handler Handler
 | |
| 	// Default buffer size to use to read incoming UDP messages. If not set
 | |
| 	// it defaults to MinMsgSize (512 B).
 | |
| 	UDPSize int
 | |
| 	// The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second.
 | |
| 	ReadTimeout time.Duration
 | |
| 	// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
 | |
| 	WriteTimeout time.Duration
 | |
| 	// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
 | |
| 	IdleTimeout func() time.Duration
 | |
| 	// Secret(s) for Tsig map[<zonename>]<base64 secret>.
 | |
| 	TsigSecret map[string]string
 | |
| 	// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
 | |
| 	// the handler. It will specifically not check if the query has the QR bit not set.
 | |
| 	Unsafe bool
 | |
| 	// If NotifyStartedFunc is set it is called once the server has started listening.
 | |
| 	NotifyStartedFunc func()
 | |
| 	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
 | |
| 	DecorateReader DecorateReader
 | |
| 	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
 | |
| 	DecorateWriter DecorateWriter
 | |
| 
 | |
| 	// Graceful shutdown handling
 | |
| 
 | |
| 	inFlight sync.WaitGroup
 | |
| 
 | |
| 	lock    sync.RWMutex
 | |
| 	started bool
 | |
| }
 | |
| 
 | |
| // ListenAndServe starts a nameserver on the configured address in *Server.
 | |
| func (srv *Server) ListenAndServe() error {
 | |
| 	srv.lock.Lock()
 | |
| 	defer srv.lock.Unlock()
 | |
| 	if srv.started {
 | |
| 		return &Error{err: "server already started"}
 | |
| 	}
 | |
| 	addr := srv.Addr
 | |
| 	if addr == "" {
 | |
| 		addr = ":domain"
 | |
| 	}
 | |
| 	if srv.UDPSize == 0 {
 | |
| 		srv.UDPSize = MinMsgSize
 | |
| 	}
 | |
| 	switch srv.Net {
 | |
| 	case "tcp", "tcp4", "tcp6":
 | |
| 		a, err := net.ResolveTCPAddr(srv.Net, addr)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		l, err := net.ListenTCP(srv.Net, a)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		srv.Listener = l
 | |
| 		srv.started = true
 | |
| 		srv.lock.Unlock()
 | |
| 		err = srv.serveTCP(l)
 | |
| 		srv.lock.Lock() // to satisfy the defer at the top
 | |
| 		return err
 | |
| 	case "tcp-tls", "tcp4-tls", "tcp6-tls":
 | |
| 		network := "tcp"
 | |
| 		if srv.Net == "tcp4-tls" {
 | |
| 			network = "tcp4"
 | |
| 		} else if srv.Net == "tcp6" {
 | |
| 			network = "tcp6"
 | |
| 		}
 | |
| 
 | |
| 		l, err := tls.Listen(network, addr, srv.TLSConfig)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		srv.Listener = l
 | |
| 		srv.started = true
 | |
| 		srv.lock.Unlock()
 | |
| 		err = srv.serveTCP(l)
 | |
| 		srv.lock.Lock() // to satisfy the defer at the top
 | |
| 		return err
 | |
| 	case "udp", "udp4", "udp6":
 | |
| 		a, err := net.ResolveUDPAddr(srv.Net, addr)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		l, err := net.ListenUDP(srv.Net, a)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		if e := setUDPSocketOptions(l); e != nil {
 | |
| 			return e
 | |
| 		}
 | |
| 		srv.PacketConn = l
 | |
| 		srv.started = true
 | |
| 		srv.lock.Unlock()
 | |
| 		err = srv.serveUDP(l)
 | |
| 		srv.lock.Lock() // to satisfy the defer at the top
 | |
| 		return err
 | |
| 	}
 | |
| 	return &Error{err: "bad network"}
 | |
| }
 | |
| 
 | |
| // ActivateAndServe starts a nameserver with the PacketConn or Listener
 | |
| // configured in *Server. Its main use is to start a server from systemd.
 | |
| func (srv *Server) ActivateAndServe() error {
 | |
| 	srv.lock.Lock()
 | |
| 	defer srv.lock.Unlock()
 | |
| 	if srv.started {
 | |
| 		return &Error{err: "server already started"}
 | |
| 	}
 | |
| 	pConn := srv.PacketConn
 | |
| 	l := srv.Listener
 | |
| 	if pConn != nil {
 | |
| 		if srv.UDPSize == 0 {
 | |
| 			srv.UDPSize = MinMsgSize
 | |
| 		}
 | |
| 		if t, ok := pConn.(*net.UDPConn); ok {
 | |
| 			if e := setUDPSocketOptions(t); e != nil {
 | |
| 				return e
 | |
| 			}
 | |
| 			srv.started = true
 | |
| 			srv.lock.Unlock()
 | |
| 			e := srv.serveUDP(t)
 | |
| 			srv.lock.Lock() // to satisfy the defer at the top
 | |
| 			return e
 | |
| 		}
 | |
| 	}
 | |
| 	if l != nil {
 | |
| 		srv.started = true
 | |
| 		srv.lock.Unlock()
 | |
| 		e := srv.serveTCP(l)
 | |
| 		srv.lock.Lock() // to satisfy the defer at the top
 | |
| 		return e
 | |
| 	}
 | |
| 	return &Error{err: "bad listeners"}
 | |
| }
 | |
| 
 | |
| // Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
 | |
| // ActivateAndServe will return. All in progress queries are completed before the server
 | |
| // is taken down. If the Shutdown is taking longer than the reading timeout an error
 | |
| // is returned.
 | |
| func (srv *Server) Shutdown() error {
 | |
| 	srv.lock.Lock()
 | |
| 	if !srv.started {
 | |
| 		srv.lock.Unlock()
 | |
| 		return &Error{err: "server not started"}
 | |
| 	}
 | |
| 	srv.started = false
 | |
| 	srv.lock.Unlock()
 | |
| 
 | |
| 	if srv.PacketConn != nil {
 | |
| 		srv.PacketConn.Close()
 | |
| 	}
 | |
| 	if srv.Listener != nil {
 | |
| 		srv.Listener.Close()
 | |
| 	}
 | |
| 
 | |
| 	fin := make(chan bool)
 | |
| 	go func() {
 | |
| 		srv.inFlight.Wait()
 | |
| 		fin <- true
 | |
| 	}()
 | |
| 
 | |
| 	select {
 | |
| 	case <-time.After(srv.getReadTimeout()):
 | |
| 		return &Error{err: "server shutdown is pending"}
 | |
| 	case <-fin:
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
 | |
| func (srv *Server) getReadTimeout() time.Duration {
 | |
| 	rtimeout := dnsTimeout
 | |
| 	if srv.ReadTimeout != 0 {
 | |
| 		rtimeout = srv.ReadTimeout
 | |
| 	}
 | |
| 	return rtimeout
 | |
| }
 | |
| 
 | |
| // serveTCP starts a TCP listener for the server.
 | |
| // Each request is handled in a separate goroutine.
 | |
| func (srv *Server) serveTCP(l net.Listener) error {
 | |
| 	defer l.Close()
 | |
| 
 | |
| 	if srv.NotifyStartedFunc != nil {
 | |
| 		srv.NotifyStartedFunc()
 | |
| 	}
 | |
| 
 | |
| 	reader := Reader(&defaultReader{srv})
 | |
| 	if srv.DecorateReader != nil {
 | |
| 		reader = srv.DecorateReader(reader)
 | |
| 	}
 | |
| 
 | |
| 	handler := srv.Handler
 | |
| 	if handler == nil {
 | |
| 		handler = DefaultServeMux
 | |
| 	}
 | |
| 	rtimeout := srv.getReadTimeout()
 | |
| 	// deadline is not used here
 | |
| 	for {
 | |
| 		rw, err := l.Accept()
 | |
| 		if err != nil {
 | |
| 			if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
 | |
| 				continue
 | |
| 			}
 | |
| 			return err
 | |
| 		}
 | |
| 		m, err := reader.ReadTCP(rw, rtimeout)
 | |
| 		srv.lock.RLock()
 | |
| 		if !srv.started {
 | |
| 			srv.lock.RUnlock()
 | |
| 			return nil
 | |
| 		}
 | |
| 		srv.lock.RUnlock()
 | |
| 		if err != nil {
 | |
| 			continue
 | |
| 		}
 | |
| 		srv.inFlight.Add(1)
 | |
| 		go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // serveUDP starts a UDP listener for the server.
 | |
| // Each request is handled in a separate goroutine.
 | |
| func (srv *Server) serveUDP(l *net.UDPConn) error {
 | |
| 	defer l.Close()
 | |
| 
 | |
| 	if srv.NotifyStartedFunc != nil {
 | |
| 		srv.NotifyStartedFunc()
 | |
| 	}
 | |
| 
 | |
| 	reader := Reader(&defaultReader{srv})
 | |
| 	if srv.DecorateReader != nil {
 | |
| 		reader = srv.DecorateReader(reader)
 | |
| 	}
 | |
| 
 | |
| 	handler := srv.Handler
 | |
| 	if handler == nil {
 | |
| 		handler = DefaultServeMux
 | |
| 	}
 | |
| 	rtimeout := srv.getReadTimeout()
 | |
| 	// deadline is not used here
 | |
| 	for {
 | |
| 		m, s, err := reader.ReadUDP(l, rtimeout)
 | |
| 		srv.lock.RLock()
 | |
| 		if !srv.started {
 | |
| 			srv.lock.RUnlock()
 | |
| 			return nil
 | |
| 		}
 | |
| 		srv.lock.RUnlock()
 | |
| 		if err != nil {
 | |
| 			continue
 | |
| 		}
 | |
| 		srv.inFlight.Add(1)
 | |
| 		go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Serve a new connection.
 | |
| func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
 | |
| 	defer srv.inFlight.Done()
 | |
| 
 | |
| 	w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
 | |
| 	if srv.DecorateWriter != nil {
 | |
| 		w.writer = srv.DecorateWriter(w)
 | |
| 	} else {
 | |
| 		w.writer = w
 | |
| 	}
 | |
| 
 | |
| 	q := 0 // counter for the amount of TCP queries we get
 | |
| 
 | |
| 	reader := Reader(&defaultReader{srv})
 | |
| 	if srv.DecorateReader != nil {
 | |
| 		reader = srv.DecorateReader(reader)
 | |
| 	}
 | |
| Redo:
 | |
| 	req := new(Msg)
 | |
| 	err := req.Unpack(m)
 | |
| 	if err != nil { // Send a FormatError back
 | |
| 		x := new(Msg)
 | |
| 		x.SetRcodeFormatError(req)
 | |
| 		w.WriteMsg(x)
 | |
| 		goto Exit
 | |
| 	}
 | |
| 	if !srv.Unsafe && req.Response {
 | |
| 		goto Exit
 | |
| 	}
 | |
| 
 | |
| 	w.tsigStatus = nil
 | |
| 	if w.tsigSecret != nil {
 | |
| 		if t := req.IsTsig(); t != nil {
 | |
| 			secret := t.Hdr.Name
 | |
| 			if _, ok := w.tsigSecret[secret]; !ok {
 | |
| 				w.tsigStatus = ErrKeyAlg
 | |
| 			}
 | |
| 			w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
 | |
| 			w.tsigTimersOnly = false
 | |
| 			w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
 | |
| 		}
 | |
| 	}
 | |
| 	h.ServeDNS(w, req) // Writes back to the client
 | |
| 
 | |
| Exit:
 | |
| 	if w.tcp == nil {
 | |
| 		return
 | |
| 	}
 | |
| 	// TODO(miek): make this number configurable?
 | |
| 	if q > maxTCPQueries { // close socket after this many queries
 | |
| 		w.Close()
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if w.hijacked {
 | |
| 		return // client calls Close()
 | |
| 	}
 | |
| 	if u != nil { // UDP, "close" and return
 | |
| 		w.Close()
 | |
| 		return
 | |
| 	}
 | |
| 	idleTimeout := tcpIdleTimeout
 | |
| 	if srv.IdleTimeout != nil {
 | |
| 		idleTimeout = srv.IdleTimeout()
 | |
| 	}
 | |
| 	m, err = reader.ReadTCP(w.tcp, idleTimeout)
 | |
| 	if err == nil {
 | |
| 		q++
 | |
| 		goto Redo
 | |
| 	}
 | |
| 	w.Close()
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
 | |
| 	conn.SetReadDeadline(time.Now().Add(timeout))
 | |
| 	l := make([]byte, 2)
 | |
| 	n, err := conn.Read(l)
 | |
| 	if err != nil || n != 2 {
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		return nil, ErrShortRead
 | |
| 	}
 | |
| 	length := binary.BigEndian.Uint16(l)
 | |
| 	if length == 0 {
 | |
| 		return nil, ErrShortRead
 | |
| 	}
 | |
| 	m := make([]byte, int(length))
 | |
| 	n, err = conn.Read(m[:int(length)])
 | |
| 	if err != nil || n == 0 {
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		return nil, ErrShortRead
 | |
| 	}
 | |
| 	i := n
 | |
| 	for i < int(length) {
 | |
| 		j, err := conn.Read(m[i:int(length)])
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		i += j
 | |
| 	}
 | |
| 	n = i
 | |
| 	m = m[:n]
 | |
| 	return m, nil
 | |
| }
 | |
| 
 | |
| func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
 | |
| 	conn.SetReadDeadline(time.Now().Add(timeout))
 | |
| 	m := make([]byte, srv.UDPSize)
 | |
| 	n, s, err := ReadFromSessionUDP(conn, m)
 | |
| 	if err != nil || n == 0 {
 | |
| 		if err != nil {
 | |
| 			return nil, nil, err
 | |
| 		}
 | |
| 		return nil, nil, ErrShortRead
 | |
| 	}
 | |
| 	m = m[:n]
 | |
| 	return m, s, nil
 | |
| }
 | |
| 
 | |
| // WriteMsg implements the ResponseWriter.WriteMsg method.
 | |
| func (w *response) WriteMsg(m *Msg) (err error) {
 | |
| 	var data []byte
 | |
| 	if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
 | |
| 		if t := m.IsTsig(); t != nil {
 | |
| 			data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
 | |
| 			if err != nil {
 | |
| 				return err
 | |
| 			}
 | |
| 			_, err = w.writer.Write(data)
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	data, err = m.Pack()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	_, err = w.writer.Write(data)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // Write implements the ResponseWriter.Write method.
 | |
| func (w *response) Write(m []byte) (int, error) {
 | |
| 	switch {
 | |
| 	case w.udp != nil:
 | |
| 		n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
 | |
| 		return n, err
 | |
| 	case w.tcp != nil:
 | |
| 		lm := len(m)
 | |
| 		if lm < 2 {
 | |
| 			return 0, io.ErrShortBuffer
 | |
| 		}
 | |
| 		if lm > MaxMsgSize {
 | |
| 			return 0, &Error{err: "message too large"}
 | |
| 		}
 | |
| 		l := make([]byte, 2, 2+lm)
 | |
| 		binary.BigEndian.PutUint16(l, uint16(lm))
 | |
| 		m = append(l, m...)
 | |
| 
 | |
| 		n, err := io.Copy(w.tcp, bytes.NewReader(m))
 | |
| 		return int(n), err
 | |
| 	}
 | |
| 	panic("not reached")
 | |
| }
 | |
| 
 | |
| // LocalAddr implements the ResponseWriter.LocalAddr method.
 | |
| func (w *response) LocalAddr() net.Addr {
 | |
| 	if w.tcp != nil {
 | |
| 		return w.tcp.LocalAddr()
 | |
| 	}
 | |
| 	return w.udp.LocalAddr()
 | |
| }
 | |
| 
 | |
| // RemoteAddr implements the ResponseWriter.RemoteAddr method.
 | |
| func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
 | |
| 
 | |
| // TsigStatus implements the ResponseWriter.TsigStatus method.
 | |
| func (w *response) TsigStatus() error { return w.tsigStatus }
 | |
| 
 | |
| // TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
 | |
| func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
 | |
| 
 | |
| // Hijack implements the ResponseWriter.Hijack method.
 | |
| func (w *response) Hijack() { w.hijacked = true }
 | |
| 
 | |
| // Close implements the ResponseWriter.Close method
 | |
| func (w *response) Close() error {
 | |
| 	// Can't close the udp conn, as that is actually the listener.
 | |
| 	if w.tcp != nil {
 | |
| 		e := w.tcp.Close()
 | |
| 		w.tcp = nil
 | |
| 		return e
 | |
| 	}
 | |
| 	return nil
 | |
| }
 |