diff --git a/pkg/proxy/winuserspace/BUILD b/pkg/proxy/winuserspace/BUILD index 6e2d5e6ec4f..1924c023c5d 100644 --- a/pkg/proxy/winuserspace/BUILD +++ b/pkg/proxy/winuserspace/BUILD @@ -38,7 +38,6 @@ go_test( name = "go_default_test", srcs = [ "proxier_test.go", - "proxysocket_test.go", "roundrobin_test.go", ], library = ":go_default_library", diff --git a/pkg/proxy/winuserspace/proxier.go b/pkg/proxy/winuserspace/proxier.go index 8bce3d639f9..74f9aa35ce2 100644 --- a/pkg/proxy/winuserspace/proxier.go +++ b/pkg/proxy/winuserspace/proxier.go @@ -262,7 +262,7 @@ func (proxier *Proxier) addServicePortPortal(servicePortPortalName ServicePortPo socket: sock, timeout: timeout, activeClients: newClientCache(), - dnsClients: newDnsClientCache(), + dnsClients: newDNSClientCache(), sessionAffinityType: api.ServiceAffinityNone, // default } proxier.setServiceInfo(servicePortPortalName, si) diff --git a/pkg/proxy/winuserspace/proxysocket.go b/pkg/proxy/winuserspace/proxysocket.go index 2f0b361dc4f..9559eaa5af6 100644 --- a/pkg/proxy/winuserspace/proxysocket.go +++ b/pkg/proxy/winuserspace/proxysocket.go @@ -17,7 +17,6 @@ limitations under the License. package winuserspace import ( - "encoding/binary" "fmt" "io" "net" @@ -28,6 +27,7 @@ import ( "time" "github.com/golang/glog" + "github.com/miekg/dns" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/kubernetes/pkg/api" @@ -237,190 +237,6 @@ func newClientCache() *clientCache { return &clientCache{clients: map[string]net.Conn{}} } -// TODO: use Go net dnsmsg library to walk DNS message format -// DNS packet header -type dnsHeader struct { - id uint16 - bits uint16 - qdCount uint16 - anCount uint16 - nsCount uint16 - arCount uint16 -} - -// DNS domain name -type dnsDomainName struct { - name string -} - -// DNS packet question section -type dnsQuestion struct { - qName dnsDomainName - qType uint16 - qClass uint16 -} - -// DNS message, only interested in question now -type dnsMsg struct { - header dnsHeader - question []dnsQuestion -} - -type dnsStruct interface { - walk(f func(field interface{}) (ok bool)) (ok bool) -} - -func (header *dnsHeader) walk(f func(field interface{}) bool) bool { - return f(&header.id) && - f(&header.bits) && - f(&header.qdCount) && - f(&header.anCount) && - f(&header.nsCount) && - f(&header.arCount) -} - -func (question *dnsQuestion) walk(f func(field interface{}) bool) bool { - return f(&question.qName) && - f(&question.qType) && - f(&question.qClass) -} - -func packDomainName(name string, buffer []byte, index int) (newIndex int, ok bool) { - if name == "" { - buffer[index] = 0 - index++ - return index, true - } - - // one more dot plus trailing 0 - if index+len(name)+2 > len(buffer) { - return len(buffer), false - } - - domains := strings.Split(name, ".") - for _, domain := range domains { - domainLen := len(domain) - if domainLen == 0 { - return len(buffer), false - } - buffer[index] = byte(domainLen) - index++ - copy(buffer[index:index+domainLen], domain) - index += domainLen - } - - buffer[index] = 0 - index++ - return index, true -} - -func unpackDomainName(buffer []byte, index int) (name string, newIndex int, ok bool) { - name = "" - - for index < len(buffer) { - cnt := int(buffer[index]) - index++ - if cnt == 0 { - break - } - - if index+cnt > len(buffer) { - return "", len(buffer), false - } - if name != "" { - name += "." - } - name += string(buffer[index : index+cnt]) - index += cnt - } - - if index >= len(buffer) { - return "", len(buffer), false - } - return name, index, true -} - -func packStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) { - ok = any.walk(func(field interface{}) bool { - switch value := field.(type) { - case *uint16: - if index+2 > len(buffer) { - return false - } - binary.BigEndian.PutUint16(buffer[index:index+2], *value) - index += 2 - return true - case *dnsDomainName: - index, ok = packDomainName((*value).name, buffer, index) - return ok - default: - return false - } - }) - - if !ok { - return len(buffer), false - } - return index, true -} - -func unpackStruct(any dnsStruct, buffer []byte, index int) (newIndex int, ok bool) { - ok = any.walk(func(field interface{}) bool { - switch value := field.(type) { - case *uint16: - if index+2 > len(buffer) { - return false - } - *value = binary.BigEndian.Uint16(buffer[index : index+2]) - index += 2 - return true - case *dnsDomainName: - (*value).name, index, ok = unpackDomainName(buffer, index) - return ok - default: - return false - } - }) - - if !ok { - return len(buffer), false - } - return index, true -} - -// Pack the message structure into buffer -func (msg *dnsMsg) packDnsMsg(buffer []byte) (length int, ok bool) { - index := 0 - - if index, ok = packStruct(&msg.header, buffer, index); !ok { - return len(buffer), false - } - - for i := 0; i < len(msg.question); i++ { - if index, ok = packStruct(&msg.question[i], buffer, index); !ok { - return len(buffer), false - } - } - return index, true -} - -// Unpack the buffer into the message structure -func (msg *dnsMsg) unpackDnsMsg(buffer []byte) (ok bool) { - index := 0 - - if index, ok = unpackStruct(&msg.header, buffer, index); !ok { - return false - } - - msg.question = make([]dnsQuestion, msg.header.qdCount) - for i := 0; i < len(msg.question); i++ { - if index, ok = unpackStruct(&msg.question[i], buffer, index); !ok { - return false - } - } - return true -} - // DNS query client classified by address and QTYPE type dnsClientQuery struct { clientAddress string @@ -436,44 +252,85 @@ type dnsClientCache struct { type dnsQueryState struct { searchIndex int32 - msg *dnsMsg + msg *dns.Msg } -func newDnsClientCache() *dnsClientCache { +func newDNSClientCache() *dnsClientCache { return &dnsClientCache{clients: map[dnsClientQuery]*dnsQueryState{}} } -func packetRequiresDnsSuffix(dnsType, dnsClass uint16) bool { +func packetRequiresDNSSuffix(dnsType, dnsClass uint16) bool { return (dnsType == dnsTypeA || dnsType == dnsTypeAAAA) && dnsClass == dnsClassInternet } -func isDnsService(portName string) bool { +func isDNSService(portName string) bool { return portName == dnsPortName } -func appendDnsSuffix(msg *dnsMsg, buffer []byte, length int, dnsSuffix string) int { - if msg == nil || len(msg.question) == 0 { +func appendDNSSuffix(msg *dns.Msg, buffer []byte, length int, dnsSuffix string) int { + if msg == nil || len(msg.Question) == 0 { glog.Warning("DNS message parameter is invalid.") return length } // Save the original name since it will be reused for next iteration - origName := msg.question[0].qName.name + origName := msg.Question[0].Name if dnsSuffix != "" { - msg.question[0].qName.name += "." + dnsSuffix + msg.Question[0].Name += dnsSuffix + "." } - len, ok := msg.packDnsMsg(buffer) - msg.question[0].qName.name = origName + mbuf, err := msg.PackBuffer(buffer) + msg.Question[0].Name = origName - if !ok { - glog.Warning("Unable to pack DNS packet.") + if err != nil { + glog.Warning("Unable to pack DNS packet. Error is: %v", err) return length } - return len + if &buffer[0] != &mbuf[0] { + glog.Warning("Buffer is too small in packing DNS packet.") + return length + } + + return len(mbuf) } -func processUnpackedDnsQueryPacket(dnsClients *dnsClientCache, msg *dnsMsg, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) int { +func recoverDNSQuestion(origName string, msg *dns.Msg, buffer []byte, length int) int { + if msg == nil || len(msg.Question) == 0 { + glog.Warning("DNS message parameter is invalid.") + return length + } + + if origName == msg.Question[0].Name { + return length + } + + msg.Question[0].Name = origName + if len(msg.Answer) > 0 { + msg.Answer[0].Header().Name = origName + } + mbuf, err := msg.PackBuffer(buffer) + + if err != nil { + glog.Warning("Unable to pack DNS packet. Error is: %v", err) + return length + } + + if &buffer[0] != &mbuf[0] { + glog.Warning("Buffer is too small in packing DNS packet.") + return length + } + + return len(mbuf) +} + +func processUnpackedDNSQueryPacket( + dnsClients *dnsClientCache, + msg *dns.Msg, + host string, + dnsQType uint16, + buffer []byte, + length int, + dnsSearch []string) int { if dnsSearch == nil || len(dnsSearch) == 0 { glog.V(1).Infof("DNS search list is not initialized and is empty.") return length @@ -490,22 +347,31 @@ func processUnpackedDnsQueryPacket(dnsClients *dnsClientCache, msg *dnsMsg, host index := atomic.SwapInt32(&state.searchIndex, state.searchIndex+1) // Also update message ID if the client retries due to previous query time out - state.msg.header.id = msg.header.id + state.msg.MsgHdr.Id = msg.MsgHdr.Id if index < 0 || index >= int32(len(dnsSearch)) { glog.V(1).Infof("Search index %d is out of range.", index) return length } - length = appendDnsSuffix(msg, buffer, length, dnsSearch[index]) + length = appendDNSSuffix(msg, buffer, length, dnsSearch[index]) return length } -func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, rcode uint16, host string, dnsQType uint16, buffer []byte, length int, dnsSearch []string) bool { +func processUnpackedDNSResponsePacket( + svrConn net.Conn, + dnsClients *dnsClientCache, + msg *dns.Msg, + rcode int, + host string, + dnsQType uint16, + buffer []byte, + length int, + dnsSearch []string) (bool, int) { var drop bool if dnsSearch == nil || len(dnsSearch) == 0 { glog.V(1).Infof("DNS search list is not initialized and is empty.") - return drop + return drop, length } dnsClients.mu.Lock() @@ -518,7 +384,7 @@ func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCac // If the reponse has failure and iteration through the search list has not // reached the end, retry on behalf of the client using the original query message drop = true - length = appendDnsSuffix(state.msg, buffer, length, dnsSearch[index]) + length = appendDNSSuffix(state.msg, buffer, length, dnsSearch[index]) _, err := svrConn.Write(buffer[0:length]) if err != nil { @@ -527,98 +393,96 @@ func processUnpackedDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCac } } } else { + length = recoverDNSQuestion(state.msg.Question[0].Name, msg, buffer, length) dnsClients.mu.Lock() delete(dnsClients.clients, dnsClientQuery{host, dnsQType}) dnsClients.mu.Unlock() } } - return drop + return drop, length } -func processDnsQueryPacket(dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) int { - msg := &dnsMsg{} - if !msg.unpackDnsMsg(buffer[:length]) { - glog.Warning("Unable to unpack DNS packet.") +func processDNSQueryPacket(dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) int { + msg := &dns.Msg{} + if err := msg.Unpack(buffer[:length]); err != nil { + glog.Warning("Unable to unpack DNS packet. Error is: %v", err) return length } // Query - Response bit that specifies whether this message is a query (0) or a response (1). - qr := msg.header.bits & 0x8000 - if qr != 0 { + if msg.MsgHdr.Response == true { glog.Warning("DNS packet should be a query message.") return length } // QDCOUNT - if msg.header.qdCount != 1 { - glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", msg.header.qdCount) + if len(msg.Question) != 1 { + glog.V(1).Infof("Number of entries in the question section of the DNS packet is: %d", len(msg.Question)) glog.V(1).Infof("DNS suffix appending does not support more than one question.") return length } // ANCOUNT, NSCOUNT, ARCOUNT - if msg.header.anCount != 0 || msg.header.nsCount != 0 || msg.header.arCount != 0 { + if len(msg.Answer) != 0 || len(msg.Ns) != 0 || len(msg.Extra) != 0 { glog.V(1).Infof("DNS packet contains more than question section.") return length } - dnsQType := msg.question[0].qType - dnsQClass := msg.question[0].qClass - if packetRequiresDnsSuffix(dnsQType, dnsQClass) { + dnsQType := msg.Question[0].Qtype + dnsQClass := msg.Question[0].Qclass + if packetRequiresDNSSuffix(dnsQType, dnsQClass) { host, _, err := net.SplitHostPort(cliAddr.String()) if err != nil { glog.V(1).Infof("Failed to get host from client address: %v", err) host = cliAddr.String() } - length = processUnpackedDnsQueryPacket(dnsClients, msg, host, dnsQType, buffer, length, dnsSearch) + length = processUnpackedDNSQueryPacket(dnsClients, msg, host, dnsQType, buffer, length, dnsSearch) } return length } -func processDnsResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) bool { +func processDNSResponsePacket(svrConn net.Conn, dnsClients *dnsClientCache, cliAddr net.Addr, buffer []byte, length int, dnsSearch []string) (bool, int) { var drop bool - msg := &dnsMsg{} - if !msg.unpackDnsMsg(buffer[:length]) { - glog.Warning("Unable to unpack DNS packet.") - return drop + msg := &dns.Msg{} + if err := msg.Unpack(buffer[:length]); err != nil { + glog.Warning("Unable to unpack DNS packet. Error is: %v", err) + return drop, length } // Query - Response bit that specifies whether this message is a query (0) or a response (1). - qr := msg.header.bits & 0x8000 - if qr == 0 { + if msg.MsgHdr.Response == false { glog.Warning("DNS packet should be a response message.") - return drop + return drop, length } // QDCOUNT - if msg.header.qdCount != 1 { - glog.V(1).Infof("Number of entries in the reponse section of the DNS packet is: %d", msg.header.qdCount) - return drop + if len(msg.Question) != 1 { + glog.V(1).Infof("Number of entries in the reponse section of the DNS packet is: %d", len(msg.Answer)) + return drop, length } - dnsQType := msg.question[0].qType - dnsQClass := msg.question[0].qClass - if packetRequiresDnsSuffix(dnsQType, dnsQClass) { + dnsQType := msg.Question[0].Qtype + dnsQClass := msg.Question[0].Qclass + if packetRequiresDNSSuffix(dnsQType, dnsQClass) { host, _, err := net.SplitHostPort(cliAddr.String()) if err != nil { glog.V(1).Infof("Failed to get host from client address: %v", err) host = cliAddr.String() } - rcode := msg.header.bits & 0xf - drop = processUnpackedDnsResponsePacket(svrConn, dnsClients, rcode, host, dnsQType, buffer, length, dnsSearch) + drop, length = processUnpackedDNSResponsePacket(svrConn, dnsClients, msg, msg.MsgHdr.Rcode, host, dnsQType, buffer, length, dnsSearch) } - return drop + return drop, length } func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serviceInfo, proxier *Proxier) { var buffer [4096]byte // 4KiB should be enough for most whole-packets var dnsSearch []string - if isDnsService(service.Port) { + if isDNSService(service.Port) { dnsSearch = []string{"", namespaceServiceDomain, serviceDomain, clusterDomain} execer := exec.New() ipconfigInterface := ipconfig.New(execer) @@ -651,8 +515,8 @@ func (udp *udpProxySocket) ProxyLoop(service ServicePortPortalName, myInfo *serv } // If this is DNS query packet - if isDnsService(service.Port) { - n = processDnsQueryPacket(myInfo.dnsClients, cliAddr, buffer[:], n, dnsSearch) + if isDNSService(service.Port) { + n = processDNSQueryPacket(myInfo.dnsClients, cliAddr, buffer[:], n, dnsSearch) } // If this is a client we know already, reuse the connection and goroutine. @@ -720,8 +584,8 @@ func (udp *udpProxySocket) proxyClient(cliAddr net.Addr, svrConn net.Conn, activ } drop := false - if isDnsService(service.Port) { - drop = processDnsResponsePacket(svrConn, dnsClients, cliAddr, buffer[:], n, dnsSearch) + if isDNSService(service.Port) { + drop, n = processDNSResponsePacket(svrConn, dnsClients, cliAddr, buffer[:], n, dnsSearch) } if !drop { diff --git a/pkg/proxy/winuserspace/proxysocket_test.go b/pkg/proxy/winuserspace/proxysocket_test.go deleted file mode 100644 index 66b94fc97b4..00000000000 --- a/pkg/proxy/winuserspace/proxysocket_test.go +++ /dev/null @@ -1,129 +0,0 @@ -/* -Copyright 2017 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package winuserspace - -import ( - "reflect" - "testing" -) - -func TestPackUnpackDnsMsgUnqualifiedName(t *testing.T) { - msg := &dnsMsg{} - var buffer [4096]byte - - msg.header.id = 1 - msg.header.qdCount = 1 - msg.question = make([]dnsQuestion, msg.header.qdCount) - msg.question[0].qClass = 0x01 - msg.question[0].qType = 0x01 - msg.question[0].qName.name = "kubernetes" - - length, ok := msg.packDnsMsg(buffer[:]) - if !ok { - t.Errorf("Pack DNS message failed.") - } - - unpackedMsg := &dnsMsg{} - if !unpackedMsg.unpackDnsMsg(buffer[:length]) { - t.Errorf("Unpack DNS message failed.") - } - - if !reflect.DeepEqual(msg, unpackedMsg) { - t.Errorf("Pack and Unpack DNS message are not consistent.") - } -} - -func TestPackUnpackDnsMsgFqdn(t *testing.T) { - msg := &dnsMsg{} - var buffer [4096]byte - - msg.header.id = 1 - msg.header.qdCount = 1 - msg.question = make([]dnsQuestion, msg.header.qdCount) - msg.question[0].qClass = 0x01 - msg.question[0].qType = 0x01 - msg.question[0].qName.name = "kubernetes.default.svc.cluster.local" - - length, ok := msg.packDnsMsg(buffer[:]) - if !ok { - t.Errorf("Pack DNS message failed.") - } - - unpackedMsg := &dnsMsg{} - if !unpackedMsg.unpackDnsMsg(buffer[:length]) { - t.Errorf("Unpack DNS message failed.") - } - - if !reflect.DeepEqual(msg, unpackedMsg) { - t.Errorf("Pack and Unpack DNS message are not consistent.") - } -} - -func TestPackUnpackDnsMsgEmptyName(t *testing.T) { - msg := &dnsMsg{} - var buffer [4096]byte - - msg.header.id = 1 - msg.header.qdCount = 1 - msg.question = make([]dnsQuestion, msg.header.qdCount) - msg.question[0].qClass = 0x01 - msg.question[0].qType = 0x01 - msg.question[0].qName.name = "" - - length, ok := msg.packDnsMsg(buffer[:]) - if !ok { - t.Errorf("Pack DNS message failed.") - } - - unpackedMsg := &dnsMsg{} - if !unpackedMsg.unpackDnsMsg(buffer[:length]) { - t.Errorf("Unpack DNS message failed.") - } - - if !reflect.DeepEqual(msg, unpackedMsg) { - t.Errorf("Pack and Unpack DNS message are not consistent.") - } -} - -func TestPackUnpackDnsMsgMultipleQuestions(t *testing.T) { - msg := &dnsMsg{} - var buffer [4096]byte - - msg.header.id = 1 - msg.header.qdCount = 2 - msg.question = make([]dnsQuestion, msg.header.qdCount) - msg.question[0].qClass = 0x01 - msg.question[0].qType = 0x01 - msg.question[0].qName.name = "kubernetes" - msg.question[1].qClass = 0x01 - msg.question[1].qType = 0x1c - msg.question[1].qName.name = "kubernetes.default" - - length, ok := msg.packDnsMsg(buffer[:]) - if !ok { - t.Errorf("Pack DNS message failed.") - } - - unpackedMsg := &dnsMsg{} - if !unpackedMsg.unpackDnsMsg(buffer[:length]) { - t.Errorf("Unpack DNS message failed.") - } - - if !reflect.DeepEqual(msg, unpackedMsg) { - t.Errorf("Pack and Unpack DNS message are not consistent.") - } -}