From 9a6c7ed3a11c9d30aac6d4f7da558d4d2235b79e Mon Sep 17 00:00:00 2001 From: Abhishek Shah Date: Wed, 4 May 2016 16:24:21 -0700 Subject: [PATCH] added pkg/dns with unit tests. --- pkg/dns/dns.go | 417 +++++++++++++++++++++++++++++++++++++++++++ pkg/dns/dns_test.go | 377 ++++++++++++++++++++++++++++++++++++++ pkg/dns/treecache.go | 312 ++++++++++++++++++++++++++++++++ 3 files changed, 1106 insertions(+) create mode 100644 pkg/dns/dns.go create mode 100644 pkg/dns/dns_test.go create mode 100644 pkg/dns/treecache.go diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go new file mode 100644 index 00000000000..6b184ce26c8 --- /dev/null +++ b/pkg/dns/dns.go @@ -0,0 +1,417 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dns + +import ( + "encoding/json" + "fmt" + "github.com/golang/glog" + "hash/fnv" + "net" + "strings" + "time" + + etcd "github.com/coreos/etcd/client" + skymsg "github.com/skynetservices/skydns/msg" + kapi "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/api/endpoints" + kcache "k8s.io/kubernetes/pkg/client/cache" + kclient "k8s.io/kubernetes/pkg/client/unversioned" + kframework "k8s.io/kubernetes/pkg/controller/framework" + kselector "k8s.io/kubernetes/pkg/fields" + "k8s.io/kubernetes/pkg/util/validation" + "k8s.io/kubernetes/pkg/util/wait" +) + +const ( + kubernetesSvcName = "kubernetes" + + // A subdomain added to the user specified domain for all services. + serviceSubdomain = "svc" + + // A subdomain added to the user specified dmoain for all pods. + podSubdomain = "pod" + + // Resync period for the kube controller loop. + resyncPeriod = 30 * time.Minute +) + +type KubeDNS struct { + kubeClient *kclient.Client + // DNS domain name. + domain string + // A cache that contains all the endpoints in the system. + endpointsStore kcache.Store + // A cache that contains all the services in the system. + servicesStore kcache.Store + cache *TreeCache + domainPath []string + eController *kframework.Controller + serviceController *kframework.Controller +} + +func NewKubeDNS(client *kclient.Client, domain string) *KubeDNS { + kd := &KubeDNS{ + kubeClient: client, + domain: domain, + cache: NewTreeCache(), + domainPath: reverseArray(strings.Split(strings.TrimRight(domain, "."), ".")), + } + kd.setEndpointsStore() + kd.setServicesStore() + return kd +} + +func (kd *KubeDNS) Start() { + go kd.eController.Run(wait.NeverStop) + go kd.serviceController.Run(wait.NeverStop) + // Wait synchronously for the Kubernetes service and add a DNS record for it. 
+ // TODO (abshah) UNCOMMENT AFTER TEST COMPLETE + //kd.waitForKubernetesService() +} + +func (kd *KubeDNS) waitForKubernetesService() (svc *kapi.Service) { + name := fmt.Sprintf("%v/%v", kapi.NamespaceDefault, kubernetesSvcName) + glog.Infof("Waiting for service: %v", name) + var err error + servicePollInterval := 1 * time.Second + for { + svc, err = kd.kubeClient.Services(kapi.NamespaceDefault).Get(kubernetesSvcName) + if err != nil || svc == nil { + glog.Infof("Ignoring error while waiting for service %v: %v. Sleeping %v before retrying.", name, err, servicePollInterval) + time.Sleep(servicePollInterval) + continue + } + break + } + return +} + +func (kd *KubeDNS) GetCacheAsJSON() string { + json, _ := kd.cache.Serialize("") + return json +} + +func (kd *KubeDNS) setServicesStore() { + // Returns a cache.ListWatch that gets all changes to services. + serviceWatch := kcache.NewListWatchFromClient(kd.kubeClient, "services", kapi.NamespaceAll, kselector.Everything()) + kd.servicesStore, kd.serviceController = kframework.NewInformer( + serviceWatch, + &kapi.Service{}, + resyncPeriod, + kframework.ResourceEventHandlerFuncs{ + AddFunc: kd.newService, + DeleteFunc: kd.removeService, + UpdateFunc: kd.updateService, + }, + ) +} + +func (kd *KubeDNS) setEndpointsStore() { + // Returns a cache.ListWatch that gets all changes to endpoints. + endpointsWatch := kcache.NewListWatchFromClient(kd.kubeClient, "endpoints", kapi.NamespaceAll, kselector.Everything()) + kd.endpointsStore, kd.eController = kframework.NewInformer( + endpointsWatch, + &kapi.Endpoints{}, + resyncPeriod, + kframework.ResourceEventHandlerFuncs{ + AddFunc: kd.handleEndpointAdd, + UpdateFunc: func(oldObj, newObj interface{}) { + // TODO: Avoid unwanted updates. + kd.handleEndpointAdd(newObj) + }, + }, + ) +} + +func (kd *KubeDNS) newService(obj interface{}) { + if service, ok := obj.(*kapi.Service); ok { + // if ClusterIP is not set, a DNS entry should not be created + if !kapi.IsServiceIPSet(service) { + kd.newHeadlessService(service) + return + } + if len(service.Spec.Ports) == 0 { + glog.Info("Unexpected service with no ports, this should not have happend: %v", service) + } + kd.newPortalService(service) + } +} + +func (kd *KubeDNS) removeService(obj interface{}) { + if s, ok := obj.(*kapi.Service); ok { + subCachePath := append(kd.domainPath, serviceSubdomain, s.Namespace, s.Name) + kd.cache.DeletePath(subCachePath...) + } +} + +func (kd *KubeDNS) updateService(oldObj, newObj interface{}) { + kd.newService(newObj) +} + +func (kd *KubeDNS) handleEndpointAdd(obj interface{}) { + if e, ok := obj.(*kapi.Endpoints); ok { + kd.addDNSUsingEndpoints(e) + } +} + +func (kd *KubeDNS) addDNSUsingEndpoints(e *kapi.Endpoints) error { + svc, err := kd.getServiceFromEndpoints(e) + if err != nil { + return err + } + if svc == nil || kapi.IsServiceIPSet(svc) { + // No headless service found corresponding to endpoints object. 
+ return nil + } + return kd.generateRecordsForHeadlessService(e, svc) +} + +func (kd *KubeDNS) getServiceFromEndpoints(e *kapi.Endpoints) (*kapi.Service, error) { + key, err := kcache.MetaNamespaceKeyFunc(e) + if err != nil { + return nil, err + } + obj, exists, err := kd.servicesStore.GetByKey(key) + if err != nil { + return nil, fmt.Errorf("failed to get service object from services store - %v", err) + } + if !exists { + glog.V(1).Infof("could not find service for endpoint %q in namespace %q", e.Name, e.Namespace) + return nil, nil + } + if svc, ok := obj.(*kapi.Service); ok { + return svc, nil + } + return nil, fmt.Errorf("got a non service object in services store %v", obj) +} + +func (kd *KubeDNS) newPortalService(service *kapi.Service) { + subCache := NewTreeCache() + recordValue, recordLabel := getSkyMsg(service.Spec.ClusterIP, 0) + subCache.SetEntry(recordLabel, recordValue) + + // Generate SRV Records + for i := range service.Spec.Ports { + port := &service.Spec.Ports[i] + if port.Name != "" && port.Protocol != "" { + srvValue := kd.generateSRVRecordValue(service, int(port.Port)) + subCache.SetEntry(recordLabel, srvValue, "_"+strings.ToLower(string(port.Protocol)), "_"+port.Name) + } + } + subCachePath := append(kd.domainPath, serviceSubdomain, service.Namespace) + kd.cache.SetSubCache(service.Name, subCache, subCachePath...) +} + +func (kd *KubeDNS) generateRecordsForHeadlessService(e *kapi.Endpoints, svc *kapi.Service) error { + // TODO: remove this after v1.4 is released and the old annotations are EOL + podHostnames, err := getPodHostnamesFromAnnotation(e.Annotations) + if err != nil { + return err + } + subCache := NewTreeCache() + glog.V(4).Infof("Endpoints Annotations: %v", e.Annotations) + for idx := range e.Subsets { + for subIdx := range e.Subsets[idx].Addresses { + address := &e.Subsets[idx].Addresses[subIdx] + endpointIP := address.IP + recordValue, endpointName := getSkyMsg(endpointIP, 0) + if hostLabel, exists := getHostname(address, podHostnames); exists { + endpointName = hostLabel + } + subCache.SetEntry(endpointName, recordValue) + for portIdx := range e.Subsets[idx].Ports { + endpointPort := &e.Subsets[idx].Ports[portIdx] + if endpointPort.Name != "" && endpointPort.Protocol != "" { + srvValue := kd.generateSRVRecordValue(svc, int(endpointPort.Port), endpointName) + subCache.SetEntry(endpointName, srvValue, "_"+strings.ToLower(string(endpointPort.Protocol)), "_"+endpointPort.Name) + } + } + } + } + subCachePath := append(kd.domainPath, serviceSubdomain, svc.Namespace) + kd.cache.SetSubCache(svc.Name, subCache, subCachePath...) 
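
For reference, a standalone sketch of the cache paths the two record generators above end up writing, using a hypothetical service "backend" in namespace "default" with a TCP port named "http" under the default "cluster.local" domain (illustrative only; not part of the patch):

package main

import (
	"fmt"
	"strings"
)

// Illustrative only: mirrors the path construction used by newPortalService
// and generateRecordsForHeadlessService above.
func main() {
	domainPath := []string{"local", "cluster"} // reverseArray of "cluster.local"
	protocol, portName := "TCP", "http"

	// The per-service sub-cache is mounted at <domainPath>/svc/<namespace>/<name>;
	// the ClusterIP (or per-endpoint) entry sits there, keyed by the label from getSkyMsg.
	svcPath := append(domainPath, "svc", "default", "backend")
	fmt.Println(strings.Join(svcPath, "/")) // local/cluster/svc/default/backend

	// SRV entries nest one level deeper per "_<protocol>" and "_<port name>" label.
	srvPath := append(svcPath, "_"+strings.ToLower(protocol), "_"+portName)
	fmt.Println(strings.Join(srvPath, "/")) // local/cluster/svc/default/backend/_tcp/_http
}
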
+ return nil +} + +func getHostname(address *kapi.EndpointAddress, podHostnames map[string]endpoints.HostRecord) (string, bool) { + if len(address.Hostname) > 0 { + return address.Hostname, true + } + if hostRecord, exists := podHostnames[address.IP]; exists && validation.IsDNS1123Label(hostRecord.HostName) { + return hostRecord.HostName, true + } + return "", false +} + +func getPodHostnamesFromAnnotation(annotations map[string]string) (map[string]endpoints.HostRecord, error) { + hostnames := map[string]endpoints.HostRecord{} + + if annotations != nil { + if serializedHostnames, exists := annotations[endpoints.PodHostnamesAnnotation]; exists && len(serializedHostnames) > 0 { + err := json.Unmarshal([]byte(serializedHostnames), &hostnames) + if err != nil { + return nil, err + } + } + } + return hostnames, nil +} + +func (kd *KubeDNS) generateSRVRecordValue(svc *kapi.Service, portNumber int, cNameLabels ...string) *skymsg.Service { + cName := strings.Join([]string{svc.Name, svc.Namespace, serviceSubdomain, kd.domain}, ".") + for _, cNameLabel := range cNameLabels { + cName = cNameLabel + "." + cName + } + recordValue, _ := getSkyMsg(cName, portNumber) + return recordValue +} + +// Generates skydns records for a headless service. +func (kd *KubeDNS) newHeadlessService(service *kapi.Service) error { + // Create an A record for every pod in the service. + // This record must be periodically updated. + // Format is as follows: + // For a service x, with pods a and b create DNS records, + // a.x.ns.domain. and, b.x.ns.domain. + key, err := kcache.MetaNamespaceKeyFunc(service) + if err != nil { + return err + } + e, exists, err := kd.endpointsStore.GetByKey(key) + if err != nil { + return fmt.Errorf("failed to get endpoints object from endpoints store - %v", err) + } + if !exists { + glog.V(1).Infof("Could not find endpoints for service %q in namespace %q. DNS records will be created once endpoints show up.", service.Name, service.Namespace) + return nil + } + if e, ok := e.(*kapi.Endpoints); ok { + return kd.generateRecordsForHeadlessService(e, service) + } + return nil +} + +func (kd *KubeDNS) Records(name string, exact bool) ([]skymsg.Service, error) { + glog.Infof("Received DNS Request:%s, exact:%v", name, exact) + trimmed := strings.TrimRight(name, ".") + segments := strings.Split(trimmed, ".") + path := reverseArray(segments) + if kd.isPodRecord(path) { + response, err := kd.getPodRecord(path) + if err == nil { + return []skymsg.Service{*response}, nil + } + return nil, err + } + + if exact { + key := path[len(path)-1] + if key == "" { + return []skymsg.Service{}, nil + } + if record, ok := kd.cache.GetEntry(key, path[:len(path)-1]...); ok { + return []skymsg.Service{*(record.(*skymsg.Service))}, nil + } + return nil, etcd.Error{Code: etcd.ErrorCodeKeyNotFound} + } + + // tmp, _ := kd.cache.Serialize("") + // glog.Infof("Searching path:%q, %v", path, tmp) + records := kd.cache.GetValuesForPathWithRegex(path...) 
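
As a companion to Records, a standalone sketch of the name-to-path transformation it relies on, including the dash-encoded pod IP convention handled by isPodRecord/getPodRecord further down (hypothetical names; assumes the default "cluster.local" domain; not part of the patch):

package main

import (
	"fmt"
	"net"
	"strings"
)

// Illustrative reverse helper with the same semantics as reverseArray above.
func reverse(ss []string) []string {
	for i, j := 0, len(ss)-1; i < j; i, j = i+1, j-1 {
		ss[i], ss[j] = ss[j], ss[i]
	}
	return ss
}

func main() {
	// Service query: trailing dot trimmed, labels reversed into a cache path.
	name := "backend.default.svc.cluster.local."
	fmt.Println(reverse(strings.Split(strings.TrimRight(name, "."), "."))) // [local cluster svc default backend]

	// Pod query: the final path segment is a dash-encoded IP.
	podName := "1-2-3-4.default.pod.cluster.local."
	podPath := reverse(strings.Split(strings.TrimRight(podName, "."), "."))
	ip := strings.Replace(podPath[len(podPath)-1], "-", ".", -1)
	fmt.Println(ip, net.ParseIP(ip) != nil) // 1.2.3.4 true
}
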
+ retval := []skymsg.Service{} + for _, val := range records { + retval = append(retval, *(val.(*skymsg.Service))) + } + glog.Infof("records:%v, retval:%v, path:%v", records, retval, path) + if len(retval) == 0 { + return nil, etcd.Error{Code: etcd.ErrorCodeKeyNotFound} + } + return retval, nil +} + +func (kd *KubeDNS) ReverseRecord(name string) (*skymsg.Service, error) { + glog.Infof("Received ReverseRecord Request:%s", name) + + segments := strings.Split(strings.TrimRight(name, "."), ".") + + for _, k := range segments { + if k == "*" || k == "any" { + return nil, fmt.Errorf("reverse can not contain wildcards") + } + } + + return nil, fmt.Errorf("must be exactly one service record") +} + +// e.g {"local", "cluster", "pod", "default", "10-0-0-1"} +func (kd *KubeDNS) isPodRecord(path []string) bool { + if len(path) != len(kd.domainPath)+3 { + return false + } + if path[len(kd.domainPath)] != "pod" { + return false + } + for _, segment := range path { + if segment == "*" { + return false + } + } + return true +} + +func (kd *KubeDNS) getPodRecord(path []string) (*skymsg.Service, error) { + ipStr := path[len(path)-1] + ip := strings.Replace(ipStr, "-", ".", -1) + if parsed := net.ParseIP(ip); parsed != nil { + msg := &skymsg.Service{ + Host: ip, + Port: 0, + Priority: 10, + Weight: 10, + Ttl: 30, + } + return msg, nil + } + return nil, fmt.Errorf("Invalid IP Address %v", ip) +} + +// Returns record in a format that SkyDNS understands. +// Also return the hash of the record. +func getSkyMsg(ip string, port int) (*skymsg.Service, string) { + msg := &skymsg.Service{ + Host: ip, + Port: port, + Priority: 10, + Weight: 10, + Ttl: 30, + } + s := fmt.Sprintf("%v", msg) + h := fnv.New32a() + h.Write([]byte(s)) + hash := fmt.Sprintf("%x", h.Sum32()) + glog.Infof("DNS Record:%s, hash:%s", s, hash) + return msg, fmt.Sprintf("%x", hash) +} + +func reverseArray(arr []string) []string { + for i := 0; i < len(arr)/2; i++ { + j := len(arr) - i - 1 + arr[i], arr[j] = arr[j], arr[i] + } + return arr +} diff --git a/pkg/dns/dns_test.go b/pkg/dns/dns_test.go new file mode 100644 index 00000000000..9a7a6e61e90 --- /dev/null +++ b/pkg/dns/dns_test.go @@ -0,0 +1,377 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dns + +import ( + "fmt" + "strings" + "testing" + + skymsg "github.com/skynetservices/skydns/msg" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + kapi "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/client/cache" + "net" +) + +const ( + testDomain = "cluster.local." 
+ basePath = "/skydns/local/cluster" + serviceSubDomain = "svc" + podSubDomain = "pod" + testService = "testservice" + testNamespace = "default" +) + +func newKubeDNS() *KubeDNS { + kd := &KubeDNS{ + domain: testDomain, + endpointsStore: cache.NewStore(cache.MetaNamespaceKeyFunc), + servicesStore: cache.NewStore(cache.MetaNamespaceKeyFunc), + cache: NewTreeCache(), + domainPath: reverseArray(strings.Split(strings.TrimRight(testDomain, "."), ".")), + } + return kd +} + +func TestPodDns(t *testing.T) { + const ( + testPodIP = "1.2.3.4" + sanitizedPodIP = "1-2-3-4" + testPodName = "testPod" + ) + kd := newKubeDNS() + + records, err := kd.Records(sanitizedPodIP+".default.pod."+kd.domain, false) + require.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, testPodIP, records[0].Host) +} + +func TestUnnamedSinglePortService(t *testing.T) { + kd := newKubeDNS() + s := newService(testNamespace, testService, "1.2.3.4", "", 80) + // Add the service + kd.newService(s) + assertDNSForClusterIP(t, kd, s) + // Delete the service + kd.removeService(s) + assertNoDNSForClusterIP(t, kd, s) +} + +func TestNamedSinglePortService(t *testing.T) { + const ( + portName1 = "http1" + portName2 = "http2" + ) + kd := newKubeDNS() + s := newService(testNamespace, testService, "1.2.3.4", portName1, 80) + // Add the service + kd.newService(s) + assertDNSForClusterIP(t, kd, s) + assertSRVForNamedPort(t, kd, s, portName1) + + newService := *s + // update the portName of the service + newService.Spec.Ports[0].Name = portName2 + kd.updateService(s, &newService) + assertDNSForClusterIP(t, kd, s) + assertSRVForNamedPort(t, kd, s, portName2) + assertNoSRVForNamedPort(t, kd, s, portName1) + + // Delete the service + kd.removeService(s) + assertNoDNSForClusterIP(t, kd, s) + assertNoSRVForNamedPort(t, kd, s, portName1) + assertNoSRVForNamedPort(t, kd, s, portName2) +} + +func TestHeadlessService(t *testing.T) { + kd := newKubeDNS() + s := newHeadlessService() + assert.NoError(t, kd.servicesStore.Add(s)) + endpoints := newEndpoints(s, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2"), newSubsetWithOnePort("", 8080, "10.0.0.3", "10.0.0.4")) + + assert.NoError(t, kd.endpointsStore.Add(endpoints)) + kd.newService(s) + assertDNSForHeadlessService(t, kd, endpoints) + kd.removeService(s) + assertNoDNSForHeadlessService(t, kd, s) +} + +func TestHeadlessServiceWithNamedPorts(t *testing.T) { + kd := newKubeDNS() + service := newHeadlessService() + // add service to store + assert.NoError(t, kd.servicesStore.Add(service)) + endpoints := newEndpoints(service, newSubsetWithTwoPorts("http1", 80, "http2", 81, "10.0.0.1", "10.0.0.2"), + newSubsetWithOnePort("https", 443, "10.0.0.3", "10.0.0.4")) + + // We expect 10 records. 6 SRV records. 4 POD records. + // add endpoints + assert.NoError(t, kd.endpointsStore.Add(endpoints)) + + // add service + kd.newService(service) + assertDNSForHeadlessService(t, kd, endpoints) + assertSRVForHeadlessService(t, kd, service, endpoints) + + // reduce endpoints + endpoints.Subsets = endpoints.Subsets[:1] + kd.handleEndpointAdd(endpoints) + // We expect 6 records. 4 SRV records. 2 POD records. 
+ assertDNSForHeadlessService(t, kd, endpoints) + assertSRVForHeadlessService(t, kd, service, endpoints) + + kd.removeService(service) + assertNoDNSForHeadlessService(t, kd, service) +} + +func TestHeadlessServiceEndpointsUpdate(t *testing.T) { + kd := newKubeDNS() + service := newHeadlessService() + // add service to store + assert.NoError(t, kd.servicesStore.Add(service)) + + endpoints := newEndpoints(service, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2")) + // add endpoints to store + assert.NoError(t, kd.endpointsStore.Add(endpoints)) + + // add service + kd.newService(service) + assertDNSForHeadlessService(t, kd, endpoints) + + // increase endpoints + endpoints.Subsets = append(endpoints.Subsets, + newSubsetWithOnePort("", 8080, "10.0.0.3", "10.0.0.4"), + ) + // expected DNSRecords = 4 + kd.handleEndpointAdd(endpoints) + assertDNSForHeadlessService(t, kd, endpoints) + + // remove all endpoints + endpoints.Subsets = []kapi.EndpointSubset{} + kd.handleEndpointAdd(endpoints) + assertNoDNSForHeadlessService(t, kd, service) + + // remove service + kd.removeService(service) + assertNoDNSForHeadlessService(t, kd, service) +} + +func TestHeadlessServiceWithDelayedEndpointsAddition(t *testing.T) { + kd := newKubeDNS() + // create service + service := newHeadlessService() + + // add service to store + assert.NoError(t, kd.servicesStore.Add(service)) + + // add service + kd.newService(service) + assertNoDNSForHeadlessService(t, kd, service) + + // create endpoints + endpoints := newEndpoints(service, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2")) + + // add endpoints to store + assert.NoError(t, kd.endpointsStore.Add(endpoints)) + + // add endpoints + kd.handleEndpointAdd(endpoints) + + assertDNSForHeadlessService(t, kd, endpoints) + + // remove service + kd.removeService(service) + assertNoDNSForHeadlessService(t, kd, service) +} + +func newService(namespace, serviceName, clusterIP, portName string, portNumber int32) *kapi.Service { + service := kapi.Service{ + ObjectMeta: kapi.ObjectMeta{ + Name: serviceName, + Namespace: namespace, + }, + Spec: kapi.ServiceSpec{ + ClusterIP: clusterIP, + Ports: []kapi.ServicePort{ + {Port: portNumber, Name: portName, Protocol: "TCP"}, + }, + }, + } + return &service +} + +func newHeadlessService() *kapi.Service { + service := kapi.Service{ + ObjectMeta: kapi.ObjectMeta{ + Name: testService, + Namespace: testNamespace, + }, + Spec: kapi.ServiceSpec{ + ClusterIP: "None", + Ports: []kapi.ServicePort{ + {Port: 0}, + }, + }, + } + return &service +} + +func newEndpoints(service *kapi.Service, subsets ...kapi.EndpointSubset) *kapi.Endpoints { + endpoints := kapi.Endpoints{ + ObjectMeta: service.ObjectMeta, + Subsets: []kapi.EndpointSubset{}, + } + + for _, subset := range subsets { + endpoints.Subsets = append(endpoints.Subsets, subset) + } + return &endpoints +} + +func newSubsetWithOnePort(portName string, port int32, ips ...string) kapi.EndpointSubset { + subset := newSubset() + subset.Ports = append(subset.Ports, kapi.EndpointPort{Port: port, Name: portName, Protocol: "TCP"}) + for _, ip := range ips { + subset.Addresses = append(subset.Addresses, kapi.EndpointAddress{IP: ip}) + } + return subset +} + +func newSubsetWithTwoPorts(portName1 string, portNumber1 int32, portName2 string, portNumber2 int32, ips ...string) kapi.EndpointSubset { + subset := newSubsetWithOnePort(portName1, portNumber1, ips...) 
+ subset.Ports = append(subset.Ports, kapi.EndpointPort{Port: portNumber2, Name: portName2, Protocol: "TCP"}) + return subset +} + +func newSubset() kapi.EndpointSubset { + subset := kapi.EndpointSubset{ + Addresses: []kapi.EndpointAddress{}, + Ports: []kapi.EndpointPort{}, + } + return subset +} + +func assertSRVForHeadlessService(t *testing.T, kd *KubeDNS, s *kapi.Service, e *kapi.Endpoints) { + for _, subset := range e.Subsets { + for _, port := range subset.Ports { + records, err := kd.Records(getSRVFQDN(kd, s, port.Name), false) + require.NoError(t, err) + assertRecordPortsMatchPort(t, port.Port, records) + assertCNameRecordsMatchEndpointIPs(t, kd, subset.Addresses, records) + } + } +} + +func assertDNSForHeadlessService(t *testing.T, kd *KubeDNS, e *kapi.Endpoints) { + records, err := kd.Records(getEndpointsFQDN(kd, e), false) + require.NoError(t, err) + endpoints := map[string]bool{} + for _, subset := range e.Subsets { + for _, endpointAddress := range subset.Addresses { + endpoints[endpointAddress.IP] = true + } + } + assert.Equal(t, len(endpoints), len(records)) + for _, record := range records { + _, found := endpoints[record.Host] + assert.True(t, found) + } +} + +func assertRecordPortsMatchPort(t *testing.T, port int32, records []skymsg.Service) { + for _, record := range records { + assert.Equal(t, port, int32(record.Port)) + } +} + +func assertCNameRecordsMatchEndpointIPs(t *testing.T, kd *KubeDNS, e []kapi.EndpointAddress, records []skymsg.Service) { + endpoints := map[string]bool{} + for _, endpointAddress := range e { + endpoints[endpointAddress.IP] = true + } + assert.Equal(t, len(e), len(records), "unexpected record count") + for _, record := range records { + _, found := endpoints[getIPForCName(t, kd, record.Host)] + assert.True(t, found, "Did not endpoint with address:%s", record.Host) + } +} + +func getIPForCName(t *testing.T, kd *KubeDNS, cname string) string { + records, err := kd.Records(cname, false) + require.NoError(t, err) + assert.Equal(t, 1, len(records), "Could not get IP for CNAME record for %s", cname) + assert.NotNil(t, net.ParseIP(records[0].Host), "Invalid IP address %q", records[0].Host) + return records[0].Host +} + +func assertNoDNSForHeadlessService(t *testing.T, kd *KubeDNS, s *kapi.Service) { + records, err := kd.Records(getServiceFQDN(kd, s), false) + require.Error(t, err) + assert.Equal(t, 0, len(records)) +} + +func assertSRVForNamedPort(t *testing.T, kd *KubeDNS, s *kapi.Service, portName string) { + records, err := kd.Records(getSRVFQDN(kd, s, portName), false) + require.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, getServiceFQDN(kd, s), records[0].Host) +} + +func assertNoSRVForNamedPort(t *testing.T, kd *KubeDNS, s *kapi.Service, portName string) { + records, err := kd.Records(getSRVFQDN(kd, s, portName), false) + require.Error(t, err) + assert.Equal(t, 0, len(records)) +} + +func assertNoDNSForClusterIP(t *testing.T, kd *KubeDNS, s *kapi.Service) { + records, err := kd.Records(getServiceFQDN(kd, s), false) + require.Error(t, err) + assert.Equal(t, 0, len(records)) +} + +func assertDNSForClusterIP(t *testing.T, kd *KubeDNS, s *kapi.Service) { + serviceFQDN := getServiceFQDN(kd, s) + queries := []string{ + serviceFQDN, + strings.Replace(serviceFQDN, ".svc.", ".*.", 1), + strings.Replace(serviceFQDN, s.Namespace, "*", 1), + strings.Replace(strings.Replace(serviceFQDN, s.Namespace, "*", 1), ".svc.", ".*.", 1), + "*." 
+ serviceFQDN, + } + for _, query := range queries { + records, err := kd.Records(query, false) + require.NoError(t, err) + assert.Equal(t, 1, len(records)) + assert.Equal(t, s.Spec.ClusterIP, records[0].Host) + } +} + +func getServiceFQDN(kd *KubeDNS, s *kapi.Service) string { + return fmt.Sprintf("%s.%s.svc.%s", s.Name, s.Namespace, kd.domain) +} + +func getEndpointsFQDN(kd *KubeDNS, e *kapi.Endpoints) string { + return fmt.Sprintf("%s.%s.svc.%s", e.ObjectMeta.Name, e.ObjectMeta.Namespace, kd.domain) +} + +func getSRVFQDN(kd *KubeDNS, s *kapi.Service, portName string) string { + return fmt.Sprintf("_%s._tcp.%s.%s.svc.%s", portName, s.Name, s.Namespace, kd.domain) +} diff --git a/pkg/dns/treecache.go b/pkg/dns/treecache.go new file mode 100644 index 00000000000..ee3ba206ee7 --- /dev/null +++ b/pkg/dns/treecache.go @@ -0,0 +1,312 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dns + +import ( + "bytes" + "crypto/md5" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path" + "reflect" + "strings" + "sync" +) + +const ( + dataFile = "data.dat" + crcFile = "data.crc" +) + +type object interface{} + +type TreeCache struct { + ChildNodes map[string]*TreeCache + Entries map[string]interface{} + m *sync.RWMutex +} + +func NewTreeCache() *TreeCache { + return &TreeCache{ + ChildNodes: make(map[string]*TreeCache), + Entries: make(map[string]interface{}), + m: &sync.RWMutex{}, + } +} + +func Deserialize(dir string) (*TreeCache, error) { + b, err := ioutil.ReadFile(path.Join(dir, dataFile)) + if err != nil { + return nil, err + } + var hash []byte + hash, err = ioutil.ReadFile(path.Join(dir, crcFile)) + if err != nil { + return nil, err + } + if !reflect.DeepEqual(hash, getMD5(b)) { + return nil, fmt.Errorf("Checksum failed") + } + + var cache TreeCache + err = json.Unmarshal(b, &cache) + if err != nil { + return nil, err + } + cache.m = &sync.RWMutex{} + return &cache, nil +} + +func (cache *TreeCache) Serialize(dir string) (string, error) { + cache.m.RLock() + defer cache.m.RUnlock() + b, err := json.Marshal(cache) + if err != nil { + return "", err + } + + if len(dir) == 0 { + var prettyJSON bytes.Buffer + err = json.Indent(&prettyJSON, b, "", "\t") + + if err != nil { + return "", err + } + return string(prettyJSON.Bytes()), nil + } + if err := ensureDir(dir, os.FileMode(0755)); err != nil { + return "", err + } + if err := ioutil.WriteFile(path.Join(dir, dataFile), b, 0644); err != nil { + return "", err + } + if err := ioutil.WriteFile(path.Join(dir, crcFile), getMD5(b), 0644); err != nil { + return "", err + } + return string(b), nil +} + +func (cache *TreeCache) SetEntry(key string, val interface{}, path ...string) { + cache.m.Lock() + defer cache.m.Unlock() + node := cache.ensureChildNode(path...) + node.Entries[key] = val +} + +func (cache *TreeCache) ReplaceEntries(entries map[string]interface{}, path ...string) { + cache.m.Lock() + defer cache.m.Unlock() + node := cache.ensureChildNode(path...) 
+ node.Entries = make(map[string]interface{}) + for key, val := range entries { + node.Entries[key] = val + } +} + +func (cache *TreeCache) GetSubCache(path ...string) *TreeCache { + childCache := cache + for _, subpath := range path { + childCache = childCache.ChildNodes[subpath] + if childCache == nil { + return childCache + } + } + return childCache +} + +func (cache *TreeCache) SetSubCache(key string, subCache *TreeCache, path ...string) { + cache.m.Lock() + defer cache.m.Unlock() + node := cache.ensureChildNode(path...) + node.ChildNodes[key] = subCache +} + +func (cache *TreeCache) GetEntry(key string, path ...string) (interface{}, bool) { + cache.m.RLock() + defer cache.m.RUnlock() + childNode := cache.GetSubCache(path...) + val, ok := childNode.Entries[key] + return val, ok +} + +func (cache *TreeCache) GetValuesForPathWithRegex(path ...string) []interface{} { + cache.m.RLock() + defer cache.m.RUnlock() + retval := []interface{}{} + nodesToExplore := []*TreeCache{cache} + for idx, subpath := range path { + nextNodesToExplore := []*TreeCache{} + if idx == len(path)-1 { + // if path ends on an entry, instead of a child node, add the entry + for _, node := range nodesToExplore { + if subpath == "*" || subpath == "any" { + nextNodesToExplore = append(nextNodesToExplore, node) + } else { + if val, ok := node.Entries[subpath]; ok { + retval = append(retval, val) + } else { + childNode := node.ChildNodes[subpath] + if childNode != nil { + nextNodesToExplore = append(nextNodesToExplore, childNode) + } + } + } + } + nodesToExplore = nextNodesToExplore + break + } + + if subpath == "*" || subpath == "any" { + for _, node := range nodesToExplore { + for subkey, subnode := range node.ChildNodes { + if !strings.HasPrefix(subkey, "_") { + nextNodesToExplore = append(nextNodesToExplore, subnode) + } + } + } + } else { + for _, node := range nodesToExplore { + childNode := node.ChildNodes[subpath] + if childNode != nil { + nextNodesToExplore = append(nextNodesToExplore, childNode) + } + } + } + nodesToExplore = nextNodesToExplore + } + + for _, node := range nodesToExplore { + for _, val := range node.Entries { + retval = append(retval, val) + } + } + + return retval +} + +func (cache *TreeCache) GetEntries(recursive bool, path ...string) []interface{} { + cache.m.RLock() + defer cache.m.RUnlock() + childNode := cache.GetSubCache(path...) + if childNode == nil { + return nil + } + + retval := [][]interface{}{{}} + childNode.appendValues(recursive, retval) + return retval[0] +} + +func (cache *TreeCache) DeletePath(path ...string) bool { + if len(path) == 0 { + return false + } + cache.m.Lock() + defer cache.m.Unlock() + if parentNode := cache.GetSubCache(path[:len(path)-1]...); parentNode != nil { + if _, ok := parentNode.ChildNodes[path[len(path)-1]]; ok { + delete(parentNode.ChildNodes, path[len(path)-1]) + return true + } + } + return false +} + +func (tn *TreeCache) DeleteEntry(key string, path ...string) bool { + tn.m.Lock() + defer tn.m.Unlock() + childNode := tn.GetSubCache(path...) 
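
A test-style sketch of the wildcard behavior implemented by GetValuesForPathWithRegex above, assuming it sits alongside the package under test; names and values are made up and the function name is illustrative:

package dns

import "testing"

// Test-style sketch (not part of the patch): the wildcard lookup behavior
// that Records depends on.
func TestWildcardLookupSketch(t *testing.T) {
	tc := NewTreeCache()
	// A-record-style entry stored at svc/default/backend.
	tc.SetEntry("abc123", "10.0.0.1", "svc", "default", "backend")
	// SRV-style entry nested under the "_tcp"/"_http" children of the same node.
	tc.SetEntry("abc123", "backend.default.svc.", "svc", "default", "backend", "_tcp", "_http")

	// A trailing "*" returns only the entries stored directly at the node it
	// reaches, so the SRV value under "_tcp"/"_http" is not mixed in.
	if vals := tc.GetValuesForPathWithRegex("svc", "default", "backend", "*"); len(vals) != 1 {
		t.Fatalf("expected 1 value, got %d", len(vals))
	}

	// A fully spelled-out path reaches the nested SRV entry.
	if vals := tc.GetValuesForPathWithRegex("svc", "default", "backend", "_tcp", "_http"); len(vals) != 1 {
		t.Fatalf("expected 1 SRV value, got %d", len(vals))
	}
}
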
+ if childNode == nil { + return false + } + if _, ok := childNode.Entries[key]; ok { + delete(childNode.Entries, key) + return true + } + return false +} + +func (tn *TreeCache) appendValues(recursive bool, ref [][]interface{}) { + for _, value := range tn.Entries { + ref[0] = append(ref[0], value) + } + if recursive { + for _, node := range tn.ChildNodes { + node.appendValues(recursive, ref) + } + } +} + +func (tn *TreeCache) ensureChildNode(path ...string) *TreeCache { + childNode := tn + for _, subpath := range path { + newNode := childNode.ChildNodes[subpath] + if newNode == nil { + newNode = NewTreeCache() + childNode.ChildNodes[subpath] = newNode + } + childNode = newNode + } + return childNode +} + +func ensureDir(path string, perm os.FileMode) error { + s, err := os.Stat(path) + if err != nil || !s.IsDir() { + return os.Mkdir(path, perm) + } + return nil +} + +func getMD5(b []byte) []byte { + h := md5.New() + h.Write(b) + return []byte(fmt.Sprintf("%x", h.Sum(nil))) +} + +func main() { + root := NewTreeCache() + fmt.Println("Adding Entries") + root.SetEntry("k", "v") + root.SetEntry("foo", "bar", "local") + root.SetEntry("foo1", "bar1", "local", "cluster") + + fmt.Println("Fetching Entries") + for _, entry := range root.GetEntries(true, "local") { + fmt.Printf("%s\n", entry) + } + + fmt.Println("Serializing") + if _, err := root.Serialize("./foo"); err != nil { + fmt.Printf("Serialization Error: %v,\n", err) + return + } + + fmt.Println("Deserializing") + tn, err := Deserialize("./foo") + if err != nil { + fmt.Printf("Deserialization Error: %v\n", err) + return + } + + fmt.Println("Fetching Entries") + for _, entry := range tn.GetEntries(true, "local") { + fmt.Printf("%s\n", entry) + } +}
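
Finally, a test-style sketch of the Serialize/Deserialize round trip (data.dat plus its data.crc checksum), again assuming it lives alongside the package; the temp-dir handling and test name are illustrative only, not part of the patch:

package dns

import (
	"io/ioutil"
	"os"
	"testing"
)

// Test-style sketch: write a small cache to disk and read it back, exercising
// the checksum verification in Deserialize.
func TestTreeCacheRoundTripSketch(t *testing.T) {
	dir, err := ioutil.TempDir("", "treecache")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	tc := NewTreeCache()
	tc.SetEntry("foo", "bar", "local", "cluster")
	if _, err := tc.Serialize(dir); err != nil {
		t.Fatalf("serialize: %v", err)
	}

	restored, err := Deserialize(dir)
	if err != nil {
		t.Fatalf("deserialize: %v", err)
	}
	if v, ok := restored.GetEntry("foo", "local", "cluster"); !ok || v != "bar" {
		t.Fatalf("unexpected entry after round trip: %v, %v", v, ok)
	}
}
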