added pkg/dns with unit tests.

This commit is contained in:
Abhishek Shah 2016-05-04 16:24:21 -07:00
parent d2dd4911ca
commit 9a6c7ed3a1
3 changed files with 1106 additions and 0 deletions

417
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,417 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dns
import (
"encoding/json"
"fmt"
"github.com/golang/glog"
"hash/fnv"
"net"
"strings"
"time"
etcd "github.com/coreos/etcd/client"
skymsg "github.com/skynetservices/skydns/msg"
kapi "k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/api/endpoints"
kcache "k8s.io/kubernetes/pkg/client/cache"
kclient "k8s.io/kubernetes/pkg/client/unversioned"
kframework "k8s.io/kubernetes/pkg/controller/framework"
kselector "k8s.io/kubernetes/pkg/fields"
"k8s.io/kubernetes/pkg/util/validation"
"k8s.io/kubernetes/pkg/util/wait"
)
const (
kubernetesSvcName = "kubernetes"
// A subdomain added to the user specified domain for all services.
serviceSubdomain = "svc"
// A subdomain added to the user specified dmoain for all pods.
podSubdomain = "pod"
// Resync period for the kube controller loop.
resyncPeriod = 30 * time.Minute
)
type KubeDNS struct {
kubeClient *kclient.Client
// DNS domain name.
domain string
// A cache that contains all the endpoints in the system.
endpointsStore kcache.Store
// A cache that contains all the services in the system.
servicesStore kcache.Store
cache *TreeCache
domainPath []string
eController *kframework.Controller
serviceController *kframework.Controller
}
func NewKubeDNS(client *kclient.Client, domain string) *KubeDNS {
kd := &KubeDNS{
kubeClient: client,
domain: domain,
cache: NewTreeCache(),
domainPath: reverseArray(strings.Split(strings.TrimRight(domain, "."), ".")),
}
kd.setEndpointsStore()
kd.setServicesStore()
return kd
}
func (kd *KubeDNS) Start() {
go kd.eController.Run(wait.NeverStop)
go kd.serviceController.Run(wait.NeverStop)
// Wait synchronously for the Kubernetes service and add a DNS record for it.
// TODO (abshah) UNCOMMENT AFTER TEST COMPLETE
//kd.waitForKubernetesService()
}
func (kd *KubeDNS) waitForKubernetesService() (svc *kapi.Service) {
name := fmt.Sprintf("%v/%v", kapi.NamespaceDefault, kubernetesSvcName)
glog.Infof("Waiting for service: %v", name)
var err error
servicePollInterval := 1 * time.Second
for {
svc, err = kd.kubeClient.Services(kapi.NamespaceDefault).Get(kubernetesSvcName)
if err != nil || svc == nil {
glog.Infof("Ignoring error while waiting for service %v: %v. Sleeping %v before retrying.", name, err, servicePollInterval)
time.Sleep(servicePollInterval)
continue
}
break
}
return
}
func (kd *KubeDNS) GetCacheAsJSON() string {
json, _ := kd.cache.Serialize("")
return json
}
func (kd *KubeDNS) setServicesStore() {
// Returns a cache.ListWatch that gets all changes to services.
serviceWatch := kcache.NewListWatchFromClient(kd.kubeClient, "services", kapi.NamespaceAll, kselector.Everything())
kd.servicesStore, kd.serviceController = kframework.NewInformer(
serviceWatch,
&kapi.Service{},
resyncPeriod,
kframework.ResourceEventHandlerFuncs{
AddFunc: kd.newService,
DeleteFunc: kd.removeService,
UpdateFunc: kd.updateService,
},
)
}
func (kd *KubeDNS) setEndpointsStore() {
// Returns a cache.ListWatch that gets all changes to endpoints.
endpointsWatch := kcache.NewListWatchFromClient(kd.kubeClient, "endpoints", kapi.NamespaceAll, kselector.Everything())
kd.endpointsStore, kd.eController = kframework.NewInformer(
endpointsWatch,
&kapi.Endpoints{},
resyncPeriod,
kframework.ResourceEventHandlerFuncs{
AddFunc: kd.handleEndpointAdd,
UpdateFunc: func(oldObj, newObj interface{}) {
// TODO: Avoid unwanted updates.
kd.handleEndpointAdd(newObj)
},
},
)
}
func (kd *KubeDNS) newService(obj interface{}) {
if service, ok := obj.(*kapi.Service); ok {
// if ClusterIP is not set, a DNS entry should not be created
if !kapi.IsServiceIPSet(service) {
kd.newHeadlessService(service)
return
}
if len(service.Spec.Ports) == 0 {
glog.Info("Unexpected service with no ports, this should not have happend: %v", service)
}
kd.newPortalService(service)
}
}
func (kd *KubeDNS) removeService(obj interface{}) {
if s, ok := obj.(*kapi.Service); ok {
subCachePath := append(kd.domainPath, serviceSubdomain, s.Namespace, s.Name)
kd.cache.DeletePath(subCachePath...)
}
}
func (kd *KubeDNS) updateService(oldObj, newObj interface{}) {
kd.newService(newObj)
}
func (kd *KubeDNS) handleEndpointAdd(obj interface{}) {
if e, ok := obj.(*kapi.Endpoints); ok {
kd.addDNSUsingEndpoints(e)
}
}
func (kd *KubeDNS) addDNSUsingEndpoints(e *kapi.Endpoints) error {
svc, err := kd.getServiceFromEndpoints(e)
if err != nil {
return err
}
if svc == nil || kapi.IsServiceIPSet(svc) {
// No headless service found corresponding to endpoints object.
return nil
}
return kd.generateRecordsForHeadlessService(e, svc)
}
func (kd *KubeDNS) getServiceFromEndpoints(e *kapi.Endpoints) (*kapi.Service, error) {
key, err := kcache.MetaNamespaceKeyFunc(e)
if err != nil {
return nil, err
}
obj, exists, err := kd.servicesStore.GetByKey(key)
if err != nil {
return nil, fmt.Errorf("failed to get service object from services store - %v", err)
}
if !exists {
glog.V(1).Infof("could not find service for endpoint %q in namespace %q", e.Name, e.Namespace)
return nil, nil
}
if svc, ok := obj.(*kapi.Service); ok {
return svc, nil
}
return nil, fmt.Errorf("got a non service object in services store %v", obj)
}
func (kd *KubeDNS) newPortalService(service *kapi.Service) {
subCache := NewTreeCache()
recordValue, recordLabel := getSkyMsg(service.Spec.ClusterIP, 0)
subCache.SetEntry(recordLabel, recordValue)
// Generate SRV Records
for i := range service.Spec.Ports {
port := &service.Spec.Ports[i]
if port.Name != "" && port.Protocol != "" {
srvValue := kd.generateSRVRecordValue(service, int(port.Port))
subCache.SetEntry(recordLabel, srvValue, "_"+strings.ToLower(string(port.Protocol)), "_"+port.Name)
}
}
subCachePath := append(kd.domainPath, serviceSubdomain, service.Namespace)
kd.cache.SetSubCache(service.Name, subCache, subCachePath...)
}
func (kd *KubeDNS) generateRecordsForHeadlessService(e *kapi.Endpoints, svc *kapi.Service) error {
// TODO: remove this after v1.4 is released and the old annotations are EOL
podHostnames, err := getPodHostnamesFromAnnotation(e.Annotations)
if err != nil {
return err
}
subCache := NewTreeCache()
glog.V(4).Infof("Endpoints Annotations: %v", e.Annotations)
for idx := range e.Subsets {
for subIdx := range e.Subsets[idx].Addresses {
address := &e.Subsets[idx].Addresses[subIdx]
endpointIP := address.IP
recordValue, endpointName := getSkyMsg(endpointIP, 0)
if hostLabel, exists := getHostname(address, podHostnames); exists {
endpointName = hostLabel
}
subCache.SetEntry(endpointName, recordValue)
for portIdx := range e.Subsets[idx].Ports {
endpointPort := &e.Subsets[idx].Ports[portIdx]
if endpointPort.Name != "" && endpointPort.Protocol != "" {
srvValue := kd.generateSRVRecordValue(svc, int(endpointPort.Port), endpointName)
subCache.SetEntry(endpointName, srvValue, "_"+strings.ToLower(string(endpointPort.Protocol)), "_"+endpointPort.Name)
}
}
}
}
subCachePath := append(kd.domainPath, serviceSubdomain, svc.Namespace)
kd.cache.SetSubCache(svc.Name, subCache, subCachePath...)
return nil
}
func getHostname(address *kapi.EndpointAddress, podHostnames map[string]endpoints.HostRecord) (string, bool) {
if len(address.Hostname) > 0 {
return address.Hostname, true
}
if hostRecord, exists := podHostnames[address.IP]; exists && validation.IsDNS1123Label(hostRecord.HostName) {
return hostRecord.HostName, true
}
return "", false
}
func getPodHostnamesFromAnnotation(annotations map[string]string) (map[string]endpoints.HostRecord, error) {
hostnames := map[string]endpoints.HostRecord{}
if annotations != nil {
if serializedHostnames, exists := annotations[endpoints.PodHostnamesAnnotation]; exists && len(serializedHostnames) > 0 {
err := json.Unmarshal([]byte(serializedHostnames), &hostnames)
if err != nil {
return nil, err
}
}
}
return hostnames, nil
}
func (kd *KubeDNS) generateSRVRecordValue(svc *kapi.Service, portNumber int, cNameLabels ...string) *skymsg.Service {
cName := strings.Join([]string{svc.Name, svc.Namespace, serviceSubdomain, kd.domain}, ".")
for _, cNameLabel := range cNameLabels {
cName = cNameLabel + "." + cName
}
recordValue, _ := getSkyMsg(cName, portNumber)
return recordValue
}
// Generates skydns records for a headless service.
func (kd *KubeDNS) newHeadlessService(service *kapi.Service) error {
// Create an A record for every pod in the service.
// This record must be periodically updated.
// Format is as follows:
// For a service x, with pods a and b create DNS records,
// a.x.ns.domain. and, b.x.ns.domain.
key, err := kcache.MetaNamespaceKeyFunc(service)
if err != nil {
return err
}
e, exists, err := kd.endpointsStore.GetByKey(key)
if err != nil {
return fmt.Errorf("failed to get endpoints object from endpoints store - %v", err)
}
if !exists {
glog.V(1).Infof("Could not find endpoints for service %q in namespace %q. DNS records will be created once endpoints show up.", service.Name, service.Namespace)
return nil
}
if e, ok := e.(*kapi.Endpoints); ok {
return kd.generateRecordsForHeadlessService(e, service)
}
return nil
}
func (kd *KubeDNS) Records(name string, exact bool) ([]skymsg.Service, error) {
glog.Infof("Received DNS Request:%s, exact:%v", name, exact)
trimmed := strings.TrimRight(name, ".")
segments := strings.Split(trimmed, ".")
path := reverseArray(segments)
if kd.isPodRecord(path) {
response, err := kd.getPodRecord(path)
if err == nil {
return []skymsg.Service{*response}, nil
}
return nil, err
}
if exact {
key := path[len(path)-1]
if key == "" {
return []skymsg.Service{}, nil
}
if record, ok := kd.cache.GetEntry(key, path[:len(path)-1]...); ok {
return []skymsg.Service{*(record.(*skymsg.Service))}, nil
}
return nil, etcd.Error{Code: etcd.ErrorCodeKeyNotFound}
}
// tmp, _ := kd.cache.Serialize("")
// glog.Infof("Searching path:%q, %v", path, tmp)
records := kd.cache.GetValuesForPathWithRegex(path...)
retval := []skymsg.Service{}
for _, val := range records {
retval = append(retval, *(val.(*skymsg.Service)))
}
glog.Infof("records:%v, retval:%v, path:%v", records, retval, path)
if len(retval) == 0 {
return nil, etcd.Error{Code: etcd.ErrorCodeKeyNotFound}
}
return retval, nil
}
func (kd *KubeDNS) ReverseRecord(name string) (*skymsg.Service, error) {
glog.Infof("Received ReverseRecord Request:%s", name)
segments := strings.Split(strings.TrimRight(name, "."), ".")
for _, k := range segments {
if k == "*" || k == "any" {
return nil, fmt.Errorf("reverse can not contain wildcards")
}
}
return nil, fmt.Errorf("must be exactly one service record")
}
// e.g {"local", "cluster", "pod", "default", "10-0-0-1"}
func (kd *KubeDNS) isPodRecord(path []string) bool {
if len(path) != len(kd.domainPath)+3 {
return false
}
if path[len(kd.domainPath)] != "pod" {
return false
}
for _, segment := range path {
if segment == "*" {
return false
}
}
return true
}
func (kd *KubeDNS) getPodRecord(path []string) (*skymsg.Service, error) {
ipStr := path[len(path)-1]
ip := strings.Replace(ipStr, "-", ".", -1)
if parsed := net.ParseIP(ip); parsed != nil {
msg := &skymsg.Service{
Host: ip,
Port: 0,
Priority: 10,
Weight: 10,
Ttl: 30,
}
return msg, nil
}
return nil, fmt.Errorf("Invalid IP Address %v", ip)
}
// Returns record in a format that SkyDNS understands.
// Also return the hash of the record.
func getSkyMsg(ip string, port int) (*skymsg.Service, string) {
msg := &skymsg.Service{
Host: ip,
Port: port,
Priority: 10,
Weight: 10,
Ttl: 30,
}
s := fmt.Sprintf("%v", msg)
h := fnv.New32a()
h.Write([]byte(s))
hash := fmt.Sprintf("%x", h.Sum32())
glog.Infof("DNS Record:%s, hash:%s", s, hash)
return msg, fmt.Sprintf("%x", hash)
}
func reverseArray(arr []string) []string {
for i := 0; i < len(arr)/2; i++ {
j := len(arr) - i - 1
arr[i], arr[j] = arr[j], arr[i]
}
return arr
}

377
pkg/dns/dns_test.go Normal file
View File

@ -0,0 +1,377 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dns
import (
"fmt"
"strings"
"testing"
skymsg "github.com/skynetservices/skydns/msg"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
kapi "k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/client/cache"
"net"
)
const (
testDomain = "cluster.local."
basePath = "/skydns/local/cluster"
serviceSubDomain = "svc"
podSubDomain = "pod"
testService = "testservice"
testNamespace = "default"
)
func newKubeDNS() *KubeDNS {
kd := &KubeDNS{
domain: testDomain,
endpointsStore: cache.NewStore(cache.MetaNamespaceKeyFunc),
servicesStore: cache.NewStore(cache.MetaNamespaceKeyFunc),
cache: NewTreeCache(),
domainPath: reverseArray(strings.Split(strings.TrimRight(testDomain, "."), ".")),
}
return kd
}
func TestPodDns(t *testing.T) {
const (
testPodIP = "1.2.3.4"
sanitizedPodIP = "1-2-3-4"
testPodName = "testPod"
)
kd := newKubeDNS()
records, err := kd.Records(sanitizedPodIP+".default.pod."+kd.domain, false)
require.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, testPodIP, records[0].Host)
}
func TestUnnamedSinglePortService(t *testing.T) {
kd := newKubeDNS()
s := newService(testNamespace, testService, "1.2.3.4", "", 80)
// Add the service
kd.newService(s)
assertDNSForClusterIP(t, kd, s)
// Delete the service
kd.removeService(s)
assertNoDNSForClusterIP(t, kd, s)
}
func TestNamedSinglePortService(t *testing.T) {
const (
portName1 = "http1"
portName2 = "http2"
)
kd := newKubeDNS()
s := newService(testNamespace, testService, "1.2.3.4", portName1, 80)
// Add the service
kd.newService(s)
assertDNSForClusterIP(t, kd, s)
assertSRVForNamedPort(t, kd, s, portName1)
newService := *s
// update the portName of the service
newService.Spec.Ports[0].Name = portName2
kd.updateService(s, &newService)
assertDNSForClusterIP(t, kd, s)
assertSRVForNamedPort(t, kd, s, portName2)
assertNoSRVForNamedPort(t, kd, s, portName1)
// Delete the service
kd.removeService(s)
assertNoDNSForClusterIP(t, kd, s)
assertNoSRVForNamedPort(t, kd, s, portName1)
assertNoSRVForNamedPort(t, kd, s, portName2)
}
func TestHeadlessService(t *testing.T) {
kd := newKubeDNS()
s := newHeadlessService()
assert.NoError(t, kd.servicesStore.Add(s))
endpoints := newEndpoints(s, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2"), newSubsetWithOnePort("", 8080, "10.0.0.3", "10.0.0.4"))
assert.NoError(t, kd.endpointsStore.Add(endpoints))
kd.newService(s)
assertDNSForHeadlessService(t, kd, endpoints)
kd.removeService(s)
assertNoDNSForHeadlessService(t, kd, s)
}
func TestHeadlessServiceWithNamedPorts(t *testing.T) {
kd := newKubeDNS()
service := newHeadlessService()
// add service to store
assert.NoError(t, kd.servicesStore.Add(service))
endpoints := newEndpoints(service, newSubsetWithTwoPorts("http1", 80, "http2", 81, "10.0.0.1", "10.0.0.2"),
newSubsetWithOnePort("https", 443, "10.0.0.3", "10.0.0.4"))
// We expect 10 records. 6 SRV records. 4 POD records.
// add endpoints
assert.NoError(t, kd.endpointsStore.Add(endpoints))
// add service
kd.newService(service)
assertDNSForHeadlessService(t, kd, endpoints)
assertSRVForHeadlessService(t, kd, service, endpoints)
// reduce endpoints
endpoints.Subsets = endpoints.Subsets[:1]
kd.handleEndpointAdd(endpoints)
// We expect 6 records. 4 SRV records. 2 POD records.
assertDNSForHeadlessService(t, kd, endpoints)
assertSRVForHeadlessService(t, kd, service, endpoints)
kd.removeService(service)
assertNoDNSForHeadlessService(t, kd, service)
}
func TestHeadlessServiceEndpointsUpdate(t *testing.T) {
kd := newKubeDNS()
service := newHeadlessService()
// add service to store
assert.NoError(t, kd.servicesStore.Add(service))
endpoints := newEndpoints(service, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2"))
// add endpoints to store
assert.NoError(t, kd.endpointsStore.Add(endpoints))
// add service
kd.newService(service)
assertDNSForHeadlessService(t, kd, endpoints)
// increase endpoints
endpoints.Subsets = append(endpoints.Subsets,
newSubsetWithOnePort("", 8080, "10.0.0.3", "10.0.0.4"),
)
// expected DNSRecords = 4
kd.handleEndpointAdd(endpoints)
assertDNSForHeadlessService(t, kd, endpoints)
// remove all endpoints
endpoints.Subsets = []kapi.EndpointSubset{}
kd.handleEndpointAdd(endpoints)
assertNoDNSForHeadlessService(t, kd, service)
// remove service
kd.removeService(service)
assertNoDNSForHeadlessService(t, kd, service)
}
func TestHeadlessServiceWithDelayedEndpointsAddition(t *testing.T) {
kd := newKubeDNS()
// create service
service := newHeadlessService()
// add service to store
assert.NoError(t, kd.servicesStore.Add(service))
// add service
kd.newService(service)
assertNoDNSForHeadlessService(t, kd, service)
// create endpoints
endpoints := newEndpoints(service, newSubsetWithOnePort("", 80, "10.0.0.1", "10.0.0.2"))
// add endpoints to store
assert.NoError(t, kd.endpointsStore.Add(endpoints))
// add endpoints
kd.handleEndpointAdd(endpoints)
assertDNSForHeadlessService(t, kd, endpoints)
// remove service
kd.removeService(service)
assertNoDNSForHeadlessService(t, kd, service)
}
func newService(namespace, serviceName, clusterIP, portName string, portNumber int32) *kapi.Service {
service := kapi.Service{
ObjectMeta: kapi.ObjectMeta{
Name: serviceName,
Namespace: namespace,
},
Spec: kapi.ServiceSpec{
ClusterIP: clusterIP,
Ports: []kapi.ServicePort{
{Port: portNumber, Name: portName, Protocol: "TCP"},
},
},
}
return &service
}
func newHeadlessService() *kapi.Service {
service := kapi.Service{
ObjectMeta: kapi.ObjectMeta{
Name: testService,
Namespace: testNamespace,
},
Spec: kapi.ServiceSpec{
ClusterIP: "None",
Ports: []kapi.ServicePort{
{Port: 0},
},
},
}
return &service
}
func newEndpoints(service *kapi.Service, subsets ...kapi.EndpointSubset) *kapi.Endpoints {
endpoints := kapi.Endpoints{
ObjectMeta: service.ObjectMeta,
Subsets: []kapi.EndpointSubset{},
}
for _, subset := range subsets {
endpoints.Subsets = append(endpoints.Subsets, subset)
}
return &endpoints
}
func newSubsetWithOnePort(portName string, port int32, ips ...string) kapi.EndpointSubset {
subset := newSubset()
subset.Ports = append(subset.Ports, kapi.EndpointPort{Port: port, Name: portName, Protocol: "TCP"})
for _, ip := range ips {
subset.Addresses = append(subset.Addresses, kapi.EndpointAddress{IP: ip})
}
return subset
}
func newSubsetWithTwoPorts(portName1 string, portNumber1 int32, portName2 string, portNumber2 int32, ips ...string) kapi.EndpointSubset {
subset := newSubsetWithOnePort(portName1, portNumber1, ips...)
subset.Ports = append(subset.Ports, kapi.EndpointPort{Port: portNumber2, Name: portName2, Protocol: "TCP"})
return subset
}
func newSubset() kapi.EndpointSubset {
subset := kapi.EndpointSubset{
Addresses: []kapi.EndpointAddress{},
Ports: []kapi.EndpointPort{},
}
return subset
}
func assertSRVForHeadlessService(t *testing.T, kd *KubeDNS, s *kapi.Service, e *kapi.Endpoints) {
for _, subset := range e.Subsets {
for _, port := range subset.Ports {
records, err := kd.Records(getSRVFQDN(kd, s, port.Name), false)
require.NoError(t, err)
assertRecordPortsMatchPort(t, port.Port, records)
assertCNameRecordsMatchEndpointIPs(t, kd, subset.Addresses, records)
}
}
}
func assertDNSForHeadlessService(t *testing.T, kd *KubeDNS, e *kapi.Endpoints) {
records, err := kd.Records(getEndpointsFQDN(kd, e), false)
require.NoError(t, err)
endpoints := map[string]bool{}
for _, subset := range e.Subsets {
for _, endpointAddress := range subset.Addresses {
endpoints[endpointAddress.IP] = true
}
}
assert.Equal(t, len(endpoints), len(records))
for _, record := range records {
_, found := endpoints[record.Host]
assert.True(t, found)
}
}
func assertRecordPortsMatchPort(t *testing.T, port int32, records []skymsg.Service) {
for _, record := range records {
assert.Equal(t, port, int32(record.Port))
}
}
func assertCNameRecordsMatchEndpointIPs(t *testing.T, kd *KubeDNS, e []kapi.EndpointAddress, records []skymsg.Service) {
endpoints := map[string]bool{}
for _, endpointAddress := range e {
endpoints[endpointAddress.IP] = true
}
assert.Equal(t, len(e), len(records), "unexpected record count")
for _, record := range records {
_, found := endpoints[getIPForCName(t, kd, record.Host)]
assert.True(t, found, "Did not endpoint with address:%s", record.Host)
}
}
func getIPForCName(t *testing.T, kd *KubeDNS, cname string) string {
records, err := kd.Records(cname, false)
require.NoError(t, err)
assert.Equal(t, 1, len(records), "Could not get IP for CNAME record for %s", cname)
assert.NotNil(t, net.ParseIP(records[0].Host), "Invalid IP address %q", records[0].Host)
return records[0].Host
}
func assertNoDNSForHeadlessService(t *testing.T, kd *KubeDNS, s *kapi.Service) {
records, err := kd.Records(getServiceFQDN(kd, s), false)
require.Error(t, err)
assert.Equal(t, 0, len(records))
}
func assertSRVForNamedPort(t *testing.T, kd *KubeDNS, s *kapi.Service, portName string) {
records, err := kd.Records(getSRVFQDN(kd, s, portName), false)
require.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, getServiceFQDN(kd, s), records[0].Host)
}
func assertNoSRVForNamedPort(t *testing.T, kd *KubeDNS, s *kapi.Service, portName string) {
records, err := kd.Records(getSRVFQDN(kd, s, portName), false)
require.Error(t, err)
assert.Equal(t, 0, len(records))
}
func assertNoDNSForClusterIP(t *testing.T, kd *KubeDNS, s *kapi.Service) {
records, err := kd.Records(getServiceFQDN(kd, s), false)
require.Error(t, err)
assert.Equal(t, 0, len(records))
}
func assertDNSForClusterIP(t *testing.T, kd *KubeDNS, s *kapi.Service) {
serviceFQDN := getServiceFQDN(kd, s)
queries := []string{
serviceFQDN,
strings.Replace(serviceFQDN, ".svc.", ".*.", 1),
strings.Replace(serviceFQDN, s.Namespace, "*", 1),
strings.Replace(strings.Replace(serviceFQDN, s.Namespace, "*", 1), ".svc.", ".*.", 1),
"*." + serviceFQDN,
}
for _, query := range queries {
records, err := kd.Records(query, false)
require.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, s.Spec.ClusterIP, records[0].Host)
}
}
func getServiceFQDN(kd *KubeDNS, s *kapi.Service) string {
return fmt.Sprintf("%s.%s.svc.%s", s.Name, s.Namespace, kd.domain)
}
func getEndpointsFQDN(kd *KubeDNS, e *kapi.Endpoints) string {
return fmt.Sprintf("%s.%s.svc.%s", e.ObjectMeta.Name, e.ObjectMeta.Namespace, kd.domain)
}
func getSRVFQDN(kd *KubeDNS, s *kapi.Service, portName string) string {
return fmt.Sprintf("_%s._tcp.%s.%s.svc.%s", portName, s.Name, s.Namespace, kd.domain)
}

312
pkg/dns/treecache.go Normal file
View File

@ -0,0 +1,312 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package dns
import (
"bytes"
"crypto/md5"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"reflect"
"strings"
"sync"
)
const (
dataFile = "data.dat"
crcFile = "data.crc"
)
type object interface{}
type TreeCache struct {
ChildNodes map[string]*TreeCache
Entries map[string]interface{}
m *sync.RWMutex
}
func NewTreeCache() *TreeCache {
return &TreeCache{
ChildNodes: make(map[string]*TreeCache),
Entries: make(map[string]interface{}),
m: &sync.RWMutex{},
}
}
func Deserialize(dir string) (*TreeCache, error) {
b, err := ioutil.ReadFile(path.Join(dir, dataFile))
if err != nil {
return nil, err
}
var hash []byte
hash, err = ioutil.ReadFile(path.Join(dir, crcFile))
if err != nil {
return nil, err
}
if !reflect.DeepEqual(hash, getMD5(b)) {
return nil, fmt.Errorf("Checksum failed")
}
var cache TreeCache
err = json.Unmarshal(b, &cache)
if err != nil {
return nil, err
}
cache.m = &sync.RWMutex{}
return &cache, nil
}
func (cache *TreeCache) Serialize(dir string) (string, error) {
cache.m.RLock()
defer cache.m.RUnlock()
b, err := json.Marshal(cache)
if err != nil {
return "", err
}
if len(dir) == 0 {
var prettyJSON bytes.Buffer
err = json.Indent(&prettyJSON, b, "", "\t")
if err != nil {
return "", err
}
return string(prettyJSON.Bytes()), nil
}
if err := ensureDir(dir, os.FileMode(0755)); err != nil {
return "", err
}
if err := ioutil.WriteFile(path.Join(dir, dataFile), b, 0644); err != nil {
return "", err
}
if err := ioutil.WriteFile(path.Join(dir, crcFile), getMD5(b), 0644); err != nil {
return "", err
}
return string(b), nil
}
func (cache *TreeCache) SetEntry(key string, val interface{}, path ...string) {
cache.m.Lock()
defer cache.m.Unlock()
node := cache.ensureChildNode(path...)
node.Entries[key] = val
}
func (cache *TreeCache) ReplaceEntries(entries map[string]interface{}, path ...string) {
cache.m.Lock()
defer cache.m.Unlock()
node := cache.ensureChildNode(path...)
node.Entries = make(map[string]interface{})
for key, val := range entries {
node.Entries[key] = val
}
}
func (cache *TreeCache) GetSubCache(path ...string) *TreeCache {
childCache := cache
for _, subpath := range path {
childCache = childCache.ChildNodes[subpath]
if childCache == nil {
return childCache
}
}
return childCache
}
func (cache *TreeCache) SetSubCache(key string, subCache *TreeCache, path ...string) {
cache.m.Lock()
defer cache.m.Unlock()
node := cache.ensureChildNode(path...)
node.ChildNodes[key] = subCache
}
func (cache *TreeCache) GetEntry(key string, path ...string) (interface{}, bool) {
cache.m.RLock()
defer cache.m.RUnlock()
childNode := cache.GetSubCache(path...)
val, ok := childNode.Entries[key]
return val, ok
}
func (cache *TreeCache) GetValuesForPathWithRegex(path ...string) []interface{} {
cache.m.RLock()
defer cache.m.RUnlock()
retval := []interface{}{}
nodesToExplore := []*TreeCache{cache}
for idx, subpath := range path {
nextNodesToExplore := []*TreeCache{}
if idx == len(path)-1 {
// if path ends on an entry, instead of a child node, add the entry
for _, node := range nodesToExplore {
if subpath == "*" || subpath == "any" {
nextNodesToExplore = append(nextNodesToExplore, node)
} else {
if val, ok := node.Entries[subpath]; ok {
retval = append(retval, val)
} else {
childNode := node.ChildNodes[subpath]
if childNode != nil {
nextNodesToExplore = append(nextNodesToExplore, childNode)
}
}
}
}
nodesToExplore = nextNodesToExplore
break
}
if subpath == "*" || subpath == "any" {
for _, node := range nodesToExplore {
for subkey, subnode := range node.ChildNodes {
if !strings.HasPrefix(subkey, "_") {
nextNodesToExplore = append(nextNodesToExplore, subnode)
}
}
}
} else {
for _, node := range nodesToExplore {
childNode := node.ChildNodes[subpath]
if childNode != nil {
nextNodesToExplore = append(nextNodesToExplore, childNode)
}
}
}
nodesToExplore = nextNodesToExplore
}
for _, node := range nodesToExplore {
for _, val := range node.Entries {
retval = append(retval, val)
}
}
return retval
}
func (cache *TreeCache) GetEntries(recursive bool, path ...string) []interface{} {
cache.m.RLock()
defer cache.m.RUnlock()
childNode := cache.GetSubCache(path...)
if childNode == nil {
return nil
}
retval := [][]interface{}{{}}
childNode.appendValues(recursive, retval)
return retval[0]
}
func (cache *TreeCache) DeletePath(path ...string) bool {
if len(path) == 0 {
return false
}
cache.m.Lock()
defer cache.m.Unlock()
if parentNode := cache.GetSubCache(path[:len(path)-1]...); parentNode != nil {
if _, ok := parentNode.ChildNodes[path[len(path)-1]]; ok {
delete(parentNode.ChildNodes, path[len(path)-1])
return true
}
}
return false
}
func (tn *TreeCache) DeleteEntry(key string, path ...string) bool {
tn.m.Lock()
defer tn.m.Unlock()
childNode := tn.GetSubCache(path...)
if childNode == nil {
return false
}
if _, ok := childNode.Entries[key]; ok {
delete(childNode.Entries, key)
return true
}
return false
}
func (tn *TreeCache) appendValues(recursive bool, ref [][]interface{}) {
for _, value := range tn.Entries {
ref[0] = append(ref[0], value)
}
if recursive {
for _, node := range tn.ChildNodes {
node.appendValues(recursive, ref)
}
}
}
func (tn *TreeCache) ensureChildNode(path ...string) *TreeCache {
childNode := tn
for _, subpath := range path {
newNode := childNode.ChildNodes[subpath]
if newNode == nil {
newNode = NewTreeCache()
childNode.ChildNodes[subpath] = newNode
}
childNode = newNode
}
return childNode
}
func ensureDir(path string, perm os.FileMode) error {
s, err := os.Stat(path)
if err != nil || !s.IsDir() {
return os.Mkdir(path, perm)
}
return nil
}
func getMD5(b []byte) []byte {
h := md5.New()
h.Write(b)
return []byte(fmt.Sprintf("%x", h.Sum(nil)))
}
func main() {
root := NewTreeCache()
fmt.Println("Adding Entries")
root.SetEntry("k", "v")
root.SetEntry("foo", "bar", "local")
root.SetEntry("foo1", "bar1", "local", "cluster")
fmt.Println("Fetching Entries")
for _, entry := range root.GetEntries(true, "local") {
fmt.Printf("%s\n", entry)
}
fmt.Println("Serializing")
if _, err := root.Serialize("./foo"); err != nil {
fmt.Printf("Serialization Error: %v,\n", err)
return
}
fmt.Println("Deserializing")
tn, err := Deserialize("./foo")
if err != nil {
fmt.Printf("Deserialization Error: %v\n", err)
return
}
fmt.Println("Fetching Entries")
for _, entry := range tn.GetEntries(true, "local") {
fmt.Printf("%s\n", entry)
}
}