Fix DNS suffix search list issue for Windows container and workaround in kube-proxy.

kube-proxy iterates over the DNS suffix search list and appends each suffix in turn to the client's DNS query.
This commit is contained in:
Jiangtian Li
2017-02-11 15:19:40 -08:00
parent 7e2c71f698
commit b9dfb69dd7
9 changed files with 780 additions and 15 deletions

View File

@@ -17,12 +17,14 @@ limitations under the License.
package winuserspace
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/golang/glog"
@@ -30,6 +32,36 @@ import (
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/proxy"
"k8s.io/kubernetes/pkg/util/exec"
"k8s.io/kubernetes/pkg/util/ipconfig"
)
// Default Kubernetes DNS suffix search list components.
// TODO: Get DNS suffix search list from docker containers.
// --dns-search option doesn't work on Windows containers and has been
// fixed recently in docker.
const (
	// clusterDomain is the Kubernetes cluster domain.
	clusterDomain = "cluster.local"

	// serviceDomain is the Kubernetes service domain inside the cluster domain.
	serviceDomain = "svc." + clusterDomain

	// namespaceServiceDomain is the service domain of the "default" namespace.
	namespaceServiceDomain = "default." + serviceDomain

	// dnsPortName is the port name that identifies the Kubernetes DNS service.
	dnsPortName = "dns"
)

// Values of the DNS question fields for which suffix handling applies.
const (
	// dnsTypeA is the DNS TYPE value A (a host address).
	dnsTypeA uint16 = 1

	// dnsTypeAAAA is the DNS TYPE value AAAA (a host IPv6 address).
	dnsTypeAAAA uint16 = 28

	// dnsClassInternet is the DNS CLASS value IN (the Internet).
	dnsClassInternet uint16 = 1
)
// Abstraction over TCP/UDP sockets which are proxied.
@@ -205,8 +237,399 @@ func newClientCache() *clientCache {
return &clientCache{clients: map[string]net.Conn{}}
}
// TODO: use Go net dnsmsg library to walk DNS message format
// dnsHeader holds the six 16-bit words of a DNS packet header, in wire order.
type dnsHeader struct {
	// id is the message identifier chosen by the client.
	id uint16
	// bits is the flag word; the code in this file reads the query/response
	// bit (0x8000) and the response code (low four bits) from it.
	bits uint16
	// Entry counts of the question, answer, authority and additional sections.
	qdCount, anCount, nsCount, arCount uint16
}
// dnsDomainName is a DNS domain name kept in dotted text form (for example
// "kubernetes.default"). It is a distinct type so that the struct walkers can
// tell a name field apart from the fixed-width integer fields.
type dnsDomainName struct{ name string }
// dnsQuestion is one entry of the question section of a DNS packet.
type dnsQuestion struct {
	// qName is the domain name being asked about.
	qName dnsDomainName
	// qType and qClass are the QTYPE and QCLASS words that follow the name.
	qType, qClass uint16
}
// dnsMsg is the part of a DNS message this proxy works with: the header and
// the question section. Answer, authority and additional records are not
// represented.
type dnsMsg struct {
	// header is the fixed packet header.
	header dnsHeader
	// question holds the entries of the question section.
	question []dnsQuestion
}
// dnsStruct is implemented by DNS wire structures that can enumerate their
// fields. walk calls f with a pointer to each field and reports whether every
// call returned true.
type dnsStruct interface {
	walk(f func(field interface{}) bool) bool
}
// walk passes a pointer to each header word to f, in wire order, and stops at
// the first call that returns false. It reports whether every call succeeded.
func (header *dnsHeader) walk(f func(field interface{}) bool) bool {
	words := [...]*uint16{
		&header.id,
		&header.bits,
		&header.qdCount,
		&header.anCount,
		&header.nsCount,
		&header.arCount,
	}
	for _, word := range words {
		if !f(word) {
			return false
		}
	}
	return true
}
// walk passes pointers to the question's name, type and class to f, in that
// order, and stops at the first call that returns false. It reports whether
// every call succeeded.
func (question *dnsQuestion) walk(f func(field interface{}) bool) bool {
	if !f(&question.qName) {
		return false
	}
	if !f(&question.qType) {
		return false
	}
	return f(&question.qClass)
}
// packDomainName encodes name into buffer starting at index using the DNS
// wire format: every dot-separated label is written as a length octet
// followed by the label bytes, and the name ends with a zero octet. An empty
// name encodes the root domain, i.e. a single zero octet.
//
// On success it returns the index just past the terminating zero and true.
// On failure it returns len(buffer) and false without writing to buffer.
// Failure happens when index is outside the buffer, the encoded name does not
// fit, a label is empty, or a label exceeds the 63 octets a DNS length octet
// may express (larger values are reserved for compression pointers).
func packDomainName(name string, buffer []byte, index int) (newIndex int, ok bool) {
	const maxLabelLen = 63

	if index < 0 {
		return len(buffer), false
	}
	if name == "" {
		if index >= len(buffer) {
			return len(buffer), false
		}
		buffer[index] = 0
		return index + 1, true
	}
	// One length octet per label (one more than the number of dots) plus the
	// trailing zero: the encoding occupies len(name)+2 octets.
	if index+len(name)+2 > len(buffer) {
		return len(buffer), false
	}

	labels := strings.Split(name, ".")
	// Validate everything up front so that a bad name never leaves a
	// half-written encoding behind in the caller's buffer.
	for _, label := range labels {
		if len(label) == 0 || len(label) > maxLabelLen {
			return len(buffer), false
		}
	}
	for _, label := range labels {
		buffer[index] = byte(len(label))
		index++
		index += copy(buffer[index:], label)
	}
	buffer[index] = 0
	return index + 1, true
}
// unpackDomainName decodes a DNS wire-format domain name that starts at
// buffer[index] and returns it in dotted text form together with the index of
// the first octet after the terminating zero. The root name decodes to "".
//
// On failure it returns "", len(buffer) and false. Failure happens when index
// is outside the buffer, a label runs past the end of the buffer, the name is
// not terminated by a zero octet, or a length octet is larger than 63. Such
// octets are compression pointers or reserved values, which this decoder does
// not follow.
func unpackDomainName(buffer []byte, index int) (name string, newIndex int, ok bool) {
	const maxLabelLen = 63

	if index < 0 {
		return "", len(buffer), false
	}
	// Accumulate bytes rather than concatenating strings so decoding stays
	// linear in the length of the name.
	var dotted []byte
	for index < len(buffer) {
		labelLen := int(buffer[index])
		index++
		if labelLen == 0 {
			// Terminator found; it may legitimately be the last octet.
			return string(dotted), index, true
		}
		if labelLen > maxLabelLen || index+labelLen > len(buffer) {
			return "", len(buffer), false
		}
		if len(dotted) > 0 {
			dotted = append(dotted, '.')
		}
		dotted = append(dotted, buffer[index:index+labelLen]...)
		index += labelLen
	}
	// Ran off the end of the buffer without seeing the terminating zero.
	return "", len(buffer), false
}
// packStruct serializes the fields of any into buffer starting at index, in
// the order any.walk presents them. *uint16 fields are written big-endian and
// *dnsDomainName fields are encoded with packDomainName; any other field type
// is an error.
//
// It returns the index after the last written octet and true, or len(buffer)
// and false when a field does not fit or cannot be encoded.
func packStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) {
	cursor := index
	writeField := func(field interface{}) bool {
		if word, isWord := field.(*uint16); isWord {
			if cursor+2 > len(buffer) {
				return false
			}
			binary.BigEndian.PutUint16(buffer[cursor:cursor+2], *word)
			cursor += 2
			return true
		}
		if domain, isName := field.(*dnsDomainName); isName {
			var packed bool
			cursor, packed = packDomainName(domain.name, buffer, cursor)
			return packed
		}
		return false
	}
	if !any.walk(writeField) {
		return len(buffer), false
	}
	return cursor, true
}
// unpackStruct fills the fields of any from buffer starting at index, in the
// order any.walk presents them. *uint16 fields are read big-endian and
// *dnsDomainName fields are decoded with unpackDomainName; any other field
// type is an error.
//
// It returns the index after the last consumed octet and true, or len(buffer)
// and false when the buffer is too short or a field cannot be decoded.
func unpackStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) {
	cursor := index
	readField := func(field interface{}) bool {
		if word, isWord := field.(*uint16); isWord {
			if cursor+2 > len(buffer) {
				return false
			}
			*word = binary.BigEndian.Uint16(buffer[cursor : cursor+2])
			cursor += 2
			return true
		}
		if domain, isName := field.(*dnsDomainName); isName {
			var decoded bool
			domain.name, cursor, decoded = unpackDomainName(buffer, cursor)
			return decoded
		}
		return false
	}
	if !any.walk(readField) {
		return len(buffer), false
	}
	return cursor, true
}
// packDnsMsg writes the header followed by every question of msg into buffer,
// starting at offset zero. It returns the number of octets written and true,
// or len(buffer) and false as soon as one part fails to pack.
func (msg *dnsMsg) packDnsMsg(buffer []byte) (length int, ok bool) {
	offset, packed := packStruct(&msg.header, buffer, 0)
	if !packed {
		return len(buffer), false
	}
	for i := range msg.question {
		offset, packed = packStruct(&msg.question[i], buffer, offset)
		if !packed {
			return len(buffer), false
		}
	}
	return offset, true
}
// unpackDnsMsg decodes the header and the question section found at the start
// of buffer into msg and reports whether decoding succeeded. Records that
// follow the question section are ignored.
//
// The question count comes from the packet itself, so it is checked against
// the octets actually remaining before any memory is allocated for it. That
// stops a tiny packet claiming 65535 questions from forcing a large
// allocation. On failure msg.question is left nil.
func (msg *dnsMsg) unpackDnsMsg(buffer []byte) (ok bool) {
	// Smallest possible question: root name (1 octet) + QTYPE + QCLASS.
	const minQuestionLen = 5

	msg.question = nil
	index, unpacked := unpackStruct(&msg.header, buffer, 0)
	if !unpacked {
		return false
	}
	count := int(msg.header.qdCount)
	if count*minQuestionLen > len(buffer)-index {
		return false
	}
	questions := make([]dnsQuestion, count)
	for i := range questions {
		if index, unpacked = unpackStruct(&questions[i], buffer, index); !unpacked {
			return false
		}
	}
	msg.question = questions
	return true
}
// dnsClientQuery is the key under which in-flight DNS queries are tracked:
// one entry per client address and question type.
type dnsClientQuery struct {
	// clientAddress is the address of the querying client.
	clientAddress string
	// dnsQType is the QTYPE of the tracked question.
	dnsQType uint16
}
// dnsClientCache tracks the DNS queries currently being resolved on behalf of
// clients. Each entry is keyed by client address and QTYPE and records how far
// along the DNS suffix search list the query is, plus the original message.
type dnsClientCache struct {
	// mu guards clients.
	mu sync.Mutex
	// clients maps a client/QTYPE pair to the state of its query.
	clients map[dnsClientQuery]*dnsQueryState
}
// dnsQueryState is the per-query bookkeeping stored in dnsClientCache.
type dnsQueryState struct {
	// searchIndex is the position in the DNS suffix search list that will be
	// tried next; it is updated with sync/atomic operations.
	searchIndex int32
	// msg is the client's original query, reused for each retry.
	msg *dnsMsg
}
// newDnsClientCache returns an empty, ready-to-use DNS client cache.
func newDnsClientCache() *dnsClientCache {
	cache := new(dnsClientCache)
	cache.clients = make(map[dnsClientQuery]*dnsQueryState)
	return cache
}
func packetRequiresDnsSuffix(dnsType, dnsClass uint16) bool {
return (dnsType == dnsTypeA || dnsType == dnsTypeAAAA) && dnsClass == dnsClassInternet
}
// isDnsService reports whether portName is the name of the Kubernetes DNS
// service port.
func isDnsService(portName string) bool {
	isDNS := dnsPortName == portName
	return isDNS
}
// appendDnsSuffix re-encodes msg into buffer with dnsSuffix appended to the
// name of its first question, and returns the length of the new packet.
//
// msg itself is left unchanged: the original question name is restored before
// returning so the same message can be reused for the next suffix. An empty
// dnsSuffix re-encodes the name as is. The root name (empty string) is never
// suffixed, because joining it with a suffix would produce a name with an
// empty leading label that cannot be packed.
//
// If msg is nil, has no question, or cannot be packed, a warning is logged,
// buffer is left untouched and the supplied length is returned.
func appendDnsSuffix(msg *dnsMsg, buffer []byte, length int, dnsSuffix string) int {
	if msg == nil || len(msg.question) == 0 {
		glog.Warning("DNS message parameter is invalid.")
		return length
	}

	first := &msg.question[0]
	// Save the original name since it will be reused for next iteration
	origName := first.qName.name
	if dnsSuffix != "" && origName != "" {
		first.qName.name = origName + "." + dnsSuffix
	}

	// Pack into scratch space so that a failure part-way through cannot leave
	// the caller's packet half overwritten while it still uses the old length.
	scratch := make([]byte, len(buffer))
	packedLen, ok := msg.packDnsMsg(scratch)
	first.qName.name = origName
	if !ok {
		glog.Warning("Unable to pack DNS packet.")
		return length
	}

	copy(buffer, scratch[:packedLen])
	return packedLen
}
// processUnpackedDnsQueryPacket rewrites a client DNS query held in buffer so
// that its question name carries the next suffix from dnsSearch, and returns
// the length of the rewritten packet.
//
// The query is tracked in dnsClients under (host, dnsQType). The first query
// for a key stores msg as the original message and starts at search index 0;
// later queries for the same key (for example client retries) continue from
// the stored index. When dnsSearch is empty or the index has run past the end
// of the list, buffer is left as received and length is returned unchanged.
func processUnpackedDnsQueryPacket(dnsClients *dnsClientCache, msg *dnsMsg, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) int {
	if len(dnsSearch) == 0 {
		glog.V(1).Infof("DNS search list is not initialized and is empty.")
		return length
	}

	key := dnsClientQuery{clientAddress: host, dnsQType: dnsQType}

	// TODO: handle concurrent queries from a client
	dnsClients.mu.Lock()
	state, found := dnsClients.clients[key]
	if !found {
		state = &dnsQueryState{searchIndex: 0, msg: msg}
		dnsClients.clients[key] = state
	}
	dnsClients.mu.Unlock()

	// Claim the current index and advance it in a single atomic step. Reading
	// the field and swapping in value+1 separately lets concurrent callers
	// claim the same index and lose an increment.
	index := atomic.AddInt32(&state.searchIndex, 1) - 1

	// Also update message ID if the client retries due to previous query time out
	// NOTE(review): state.msg is also read by the response path without any
	// lock; that is part of the TODO above.
	state.msg.header.id = msg.header.id

	if index < 0 || index >= int32(len(dnsSearch)) {
		glog.V(1).Infof("Search index %d is out of range.", index)
		return length
	}

	return appendDnsSuffix(msg, buffer, length, dnsSearch[index])
}
// processUnpackedDnsResponsePacket decides what to do with a DNS response for
// the query tracked under (host, dnsQType) and reports whether the response
// must be dropped instead of being forwarded to the client.
//
// If the response carries a failure rcode and dnsSearch still has untried
// suffixes, the original query is re-sent to the server over svrConn with the
// next suffix appended, and true is returned. Otherwise the tracking entry is
// removed and false is returned so the caller forwards the response. An empty
// dnsSearch or an untracked query also yields false.
func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, rcode uint16, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) bool {
	if len(dnsSearch) == 0 {
		glog.V(1).Infof("DNS search list is not initialized and is empty.")
		return false
	}

	key := dnsClientQuery{clientAddress: host, dnsQType: dnsQType}

	dnsClients.mu.Lock()
	state, found := dnsClients.clients[key]
	dnsClients.mu.Unlock()
	if !found {
		return false
	}

	// Claim the current index and advance it in a single atomic step; a
	// separate read followed by a swap can lose increments when the query and
	// response paths run concurrently.
	index := atomic.AddInt32(&state.searchIndex, 1) - 1

	if rcode == 0 || index < 0 || index >= int32(len(dnsSearch)) {
		// Either the lookup succeeded or the search list is exhausted: the
		// query is finished and the response goes back to the client.
		dnsClients.mu.Lock()
		delete(dnsClients.clients, key)
		dnsClients.mu.Unlock()
		return false
	}

	// If the response has failure and iteration through the search list has not
	// reached the end, retry on behalf of the client using the original query message
	length = appendDnsSuffix(state.msg, buffer, length, dnsSearch[index])
	if _, err := svrConn.Write(buffer[0:length]); err != nil {
		if !logTimeout(err) {
			glog.Errorf("Write failed: %v", err)
		}
	}
	return true
}
// processDnsQueryPacket inspects a packet received from a client and, when it
// is a plain single-question A/AAAA query of class IN, rewrites it in buffer
// with the next DNS search suffix. It returns the length of the packet to
// forward, which equals length whenever the packet is left untouched.
//
// Packets are left untouched when they cannot be unpacked, are not queries,
// do not contain exactly one question, or carry answer, authority or
// additional records.
func processDnsQueryPacket(dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) int {
	query := &dnsMsg{}
	if !query.unpackDnsMsg(buffer[:length]) {
		glog.Warning("Unable to unpack DNS packet.")
		return length
	}

	hdr := &query.header
	switch {
	// Query - Response bit that specifies whether this message is a query (0) or a response (1).
	case hdr.bits&0x8000 != 0:
		glog.Warning("DNS packet should be a query message.")
		return length
	// QDCOUNT
	case hdr.qdCount != 1:
		glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", hdr.qdCount)
		glog.V(1).Infof("DNS suffix appending does not support more than one question.")
		return length
	// ANCOUNT, NSCOUNT, ARCOUNT
	case hdr.anCount != 0 || hdr.nsCount != 0 || hdr.arCount != 0:
		glog.V(1).Infof("DNS packet contains more than question section.")
		return length
	}

	qType, qClass := query.question[0].qType, query.question[0].qClass
	if !packetRequiresDnsSuffix(qType, qClass) {
		return length
	}

	// Track clients by host only; fall back to the full address string when
	// it cannot be split into host and port.
	host, _, err := net.SplitHostPort(cliAddr.String())
	if err != nil {
		glog.V(1).Infof("Failed to get host from client address: %v", err)
		host = cliAddr.String()
	}
	return processUnpackedDnsQueryPacket(dnsClients, query, host, qType, buffer, length, dnsSearch)
}
// processDnsResponsePacket inspects a packet received from the DNS server and
// reports whether it must be dropped rather than forwarded to the client.
//
// Only single-question A/AAAA responses of class IN are considered; for those
// the decision is delegated to processUnpackedDnsResponsePacket together with
// the response code taken from the low four bits of the header flags. Packets
// that cannot be unpacked, are not responses, or do not contain exactly one
// question are never dropped.
func processDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) bool {
	msg := &dnsMsg{}
	if !msg.unpackDnsMsg(buffer[:length]) {
		glog.Warning("Unable to unpack DNS packet.")
		return false
	}

	// Query - Response bit that specifies whether this message is a query (0) or a response (1).
	if msg.header.bits&0x8000 == 0 {
		glog.Warning("DNS packet should be a response message.")
		return false
	}

	// QDCOUNT counts the entries of the question section, also in responses.
	if msg.header.qdCount != 1 {
		glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", msg.header.qdCount)
		return false
	}

	dnsQType := msg.question[0].qType
	dnsQClass := msg.question[0].qClass
	if !packetRequiresDnsSuffix(dnsQType, dnsQClass) {
		return false
	}

	// Track clients by host only; fall back to the full address string when
	// it cannot be split into host and port.
	host, _, err := net.SplitHostPort(cliAddr.String())
	if err != nil {
		glog.V(1).Infof("Failed to get host from client address: %v", err)
		host = cliAddr.String()
	}

	rcode := msg.header.bits & 0xf
	return processUnpackedDnsResponsePacket(svrConn, dnsClients, rcode, host, dnsQType, buffer, length, dnsSearch)
}
func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serviceInfo, proxier *Proxier) {
var buffer [4096]byte // 4KiB should be enough for most whole-packets
var dnsSearch []string
if isDnsService(service.Port) {
dnsSearch = []string{"", namespaceServiceDomain, serviceDomain, clusterDomain}
execer := exec.New()
ipconfigInterface := ipconfig.New(execer)
suffixList, err := ipconfigInterface.GetDnsSuffixSearchList()
if err == nil {
for _, suffix := range suffixList {
dnsSearch = append(dnsSearch, suffix)
}
}
}
for {
if !myInfo.isAlive() {
// The service port was closed or replaced.
@@ -226,8 +649,14 @@ func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serv
glog.Errorf("ReadFrom failed, exiting ProxyLoop: %v", err)
break
}
// If this is DNS query packet
if isDnsService(service.Port) {
n = processDnsQueryPacket(myInfo.dnsClients, cliAddr, buffer[:], n, dnsSearch)
}
// If this is a client we know already, reuse the connection and goroutine.
svrConn, err := udp.getBackendConn(myInfo.activeClients, cliAddr, proxier, service, myInfo.timeout)
svrConn, err := udp.getBackendConn(myInfo.activeClients, myInfo.dnsClients, cliAddr, proxier, service, myInfo.timeout, dnsSearch)
if err != nil {
continue
}
@@ -249,7 +678,7 @@ func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serv
}
}
func (udp *udpProxySocket) getBackendConn(activeClients *clientCache, cliAddr net.Addr, proxier *Proxier, service ServicePortPortalName, timeout time.Duration) (net.Conn, error) {
func (udp *udpProxySocket) getBackendConn(activeClients *clientCache, dnsClients *dnsClientCache, cliAddr net.Addr, proxier *Proxier, service ServicePortPortalName, timeout time.Duration, dnsSearch []string) (net.Conn, error) {
activeClients.mu.Lock()
defer activeClients.mu.Unlock()
@@ -268,17 +697,17 @@ func (udp *udpProxySocket) getBackendConn(activeClients *clientCache, cliAddr ne
return nil, err
}
activeClients.clients[cliAddr.String()] = svrConn
go func(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, timeout time.Duration) {
go func(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, dnsClients *dnsClientCache, service ServicePortPortalName, timeout time.Duration, dnsSearch []string) {
defer runtime.HandleCrash()
udp.proxyClient(cliAddr, svrConn, activeClients, timeout)
}(cliAddr, svrConn, activeClients, timeout)
udp.proxyClient(cliAddr, svrConn, activeClients, dnsClients, service, timeout, dnsSearch)
}(cliAddr, svrConn, activeClients, dnsClients, service, timeout, dnsSearch)
}
return svrConn, nil
}
// This function is expected to be called as a goroutine.
// TODO: Track and log bytes copied, like TCP
func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, timeout time.Duration) {
func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activeClients *clientCache, dnsClients *dnsClientCache, service ServicePortPortalName, timeout time.Duration, dnsSearch []string) {
defer svrConn.Close()
var buffer [4096]byte
for {
@@ -289,17 +718,25 @@ func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activ
}
break
}
err = svrConn.SetDeadline(time.Now().Add(timeout))
if err != nil {
glog.Errorf("SetDeadline failed: %v", err)
break
drop := false
if isDnsService(service.Port) {
drop = processDnsResponsePacket(svrConn, dnsClients, cliAddr, buffer[:], n, dnsSearch)
}
n, err = udp.WriteTo(buffer[0:n], cliAddr)
if err != nil {
if !logTimeout(err) {
glog.Errorf("WriteTo failed: %v", err)
if !drop {
err = svrConn.SetDeadline(time.Now().Add(timeout))
if err != nil {
glog.Errorf("SetDeadline failed: %v", err)
break
}
n, err = udp.WriteTo(buffer[0:n], cliAddr)
if err != nil {
if !logTimeout(err) {
glog.Errorf("WriteTo failed: %v", err)
}
break
}
break
}
}
activeClients.mu.Lock()