Load input object for increased safety

This commit is contained in:
kargakis 2015-05-07 17:30:28 +02:00
parent ca0f678b9a
commit 93dc839b19
5 changed files with 221 additions and 66 deletions

View File

@ -171,6 +171,10 @@ func NewAPIFactory() (*cmdutil.Factory, *testFactory, runtime.Codec) {
t := &testFactory{
Validator: validation.NullSchema{},
}
generators := map[string]kubectl.Generator{
"run-container/v1": kubectl.BasicReplicationController{},
"service/v1": kubectl.ServiceGenerator{},
}
return &cmdutil.Factory{
Object: func() (meta.RESTMapper, runtime.ObjectTyper) {
return latest.RESTMapper, api.Scheme
@ -200,6 +204,10 @@ func NewAPIFactory() (*cmdutil.Factory, *testFactory, runtime.Codec) {
ClientConfig: func() (*client.Config, error) {
return t.ClientConfig, t.Err
},
Generator: func(name string) (kubectl.Generator, bool) {
generator, ok := generators[name]
return generator, ok
},
}, t, testapi.Codec()
}

View File

@ -86,8 +86,28 @@ func RunExpose(f *cmdutil.Factory, out io.Writer, cmd *cobra.Command, args []str
return err
}
generatorName := cmdutil.GetFlagString(cmd, "generator")
// Get the input object
var mapping *meta.RESTMapping
mapper, typer := f.Object()
v, k, err := mapper.VersionAndKindForResource(res)
if err != nil {
return err
}
mapping, err = mapper.RESTMapping(k, v)
if err != nil {
return err
}
client, err := f.RESTClient(mapping)
if err != nil {
return err
}
inputObject, err := resource.NewHelper(client, mapping).Get(namespace, name)
if err != nil {
return err
}
// Get the generator, setup and validate all required parameters
generatorName := cmdutil.GetFlagString(cmd, "generator")
generator, found := f.Generator(generatorName)
if !found {
return cmdutil.UsageError(cmd, fmt.Sprintf("generator %q not found.", generator))
@ -99,19 +119,9 @@ func RunExpose(f *cmdutil.Factory, out io.Writer, cmd *cobra.Command, args []str
} else {
params["name"] = cmdutil.GetFlagString(cmd, "service-name")
}
var mapping *meta.RESTMapping
if s, found := params["selector"]; !found || len(s) == 0 || cmdutil.GetFlagInt(cmd, "port") < 1 {
mapper, _ := f.Object()
v, k, err := mapper.VersionAndKindForResource(res)
if err != nil {
return err
}
mapping, err = mapper.RESTMapping(k, v)
if err != nil {
return err
}
if len(s) == 0 {
s, err := f.PodSelectorForResource(mapping, namespace, name)
s, err := f.PodSelectorForObject(inputObject)
if err != nil {
return err
}
@ -125,7 +135,7 @@ func RunExpose(f *cmdutil.Factory, out io.Writer, cmd *cobra.Command, args []str
}
}
if cmdutil.GetFlagInt(cmd, "port") < 0 && !noPorts {
ports, err := f.PortsForResource(mapping, namespace, name)
ports, err := f.PortsForObject(inputObject)
if err != nil {
return err
}
@ -141,12 +151,12 @@ func RunExpose(f *cmdutil.Factory, out io.Writer, cmd *cobra.Command, args []str
if cmdutil.GetFlagBool(cmd, "create-external-load-balancer") {
params["create-external-load-balancer"] = "true"
}
err = kubectl.ValidateParams(names, params)
if err != nil {
return err
}
// Expose new object
object, err := generator.Generate(params)
if err != nil {
return err
@ -162,7 +172,6 @@ func RunExpose(f *cmdutil.Factory, out io.Writer, cmd *cobra.Command, args []str
// TODO: extract this flag to a central location, when such a location exists.
if !cmdutil.GetFlagBool(cmd, "dry-run") {
mapper, typer := f.Object()
resourceMapper := &resource.Mapper{typer, mapper, f.ClientMapperForCommand()}
info, err := resourceMapper.InfoForObject(object)
if err != nil {

View File

@ -0,0 +1,98 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
"bytes"
"net/http"
"strings"
"testing"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
)
func TestRunExposeService(t *testing.T) {
tests := []struct {
name string
args []string
input runtime.Object
flags map[string]string
output runtime.Object
status int
}{
{
name: "expose-service-from-service",
args: []string{"service", "baz"},
input: &api.Service{
ObjectMeta: api.ObjectMeta{Name: "baz", Namespace: "test", ResourceVersion: "12"},
Spec: api.ServiceSpec{
Selector: map[string]string{"app": "go"},
},
},
flags: map[string]string{"selector": "func=stream", "protocol": "UDP", "port": "14", "service-name": "foo"},
output: &api.Service{
ObjectMeta: api.ObjectMeta{Name: "foo", Namespace: "test", ResourceVersion: "12"},
Spec: api.ServiceSpec{
Ports: []api.ServicePort{
{
Name: "default",
Protocol: api.Protocol("UDP"),
Port: 14,
},
},
Selector: map[string]string{"func": "stream"},
},
},
status: 200,
},
}
for _, test := range tests {
f, tf, codec := NewAPIFactory()
tf.Printer = &testPrinter{}
tf.Client = &client.FakeRESTClient{
Codec: codec,
Client: client.HTTPClientFunc(func(req *http.Request) (*http.Response, error) {
switch p, m := req.URL.Path, req.Method; {
case p == "/namespaces/test/services/baz" && m == "GET":
return &http.Response{StatusCode: test.status, Body: objBody(codec, test.input)}, nil
case p == "/namespaces/test/services" && m == "POST":
return &http.Response{StatusCode: test.status, Body: objBody(codec, test.output)}, nil
default:
t.Fatalf("unexpected request: %#v\n%#v", req.URL, req)
return nil, nil
}
}),
}
tf.Namespace = "test"
buf := bytes.NewBuffer([]byte{})
cmd := NewCmdExposeService(f, buf)
cmd.SetOutput(buf)
for flag, value := range test.flags {
cmd.Flags().Set(flag, value)
}
cmd.Run(cmd, test.args)
out := buf.String()
if strings.Contains(out, "services/foo") {
t.Errorf("%s: unexpected output: %s", test.name, out)
}
}
}

View File

@ -69,11 +69,10 @@ type Factory struct {
Resizer func(mapping *meta.RESTMapping) (kubectl.Resizer, error)
// Returns a Reaper for gracefully shutting down resources.
Reaper func(mapping *meta.RESTMapping) (kubectl.Reaper, error)
// PodSelectorForResource returns the pod selector associated with the provided resource name
// or an error.
PodSelectorForResource func(mapping *meta.RESTMapping, namespace, name string) (string, error)
// PortForResource returns the ports associated with the provided resource name or an error
PortsForResource func(mapping *meta.RESTMapping, namespace, name string) ([]string, error)
// PodSelectorForObject returns the pod selector associated with the provided object
PodSelectorForObject func(object runtime.Object) (string, error)
// PortsForObject returns the ports associated with the provided object
PortsForObject func(object runtime.Object) ([]string, error)
// Returns a schema that can validate objects stored on disk.
Validator func() (validation.Schema, error)
// Returns the default namespace to use in cases where no other namespace is specified
@ -145,62 +144,42 @@ func NewFactory(optionalClientConfig clientcmd.ClientConfig) *Factory {
Printer: func(mapping *meta.RESTMapping, noHeaders bool) (kubectl.ResourcePrinter, error) {
return kubectl.NewHumanReadablePrinter(noHeaders), nil
},
PodSelectorForResource: func(mapping *meta.RESTMapping, namespace, name string) (string, error) {
PodSelectorForObject: func(object runtime.Object) (string, error) {
// TODO: replace with a swagger schema based approach (identify pod selector via schema introspection)
client, err := clients.ClientForVersion("")
if err != nil {
return "", err
}
switch mapping.Kind {
case "ReplicationController":
rc, err := client.ReplicationControllers(namespace).Get(name)
if err != nil {
return "", err
}
return kubectl.MakeLabels(rc.Spec.Selector), nil
case "Pod":
pod, err := client.Pods(namespace).Get(name)
if err != nil {
return "", err
}
if len(pod.Labels) == 0 {
switch t := object.(type) {
case *api.ReplicationController:
return kubectl.MakeLabels(t.Spec.Selector), nil
case *api.Pod:
if len(t.Labels) == 0 {
return "", fmt.Errorf("the pod has no labels and cannot be exposed")
}
return kubectl.MakeLabels(pod.Labels), nil
case "Service":
svc, err := client.Services(namespace).Get(name)
return kubectl.MakeLabels(t.Labels), nil
case *api.Service:
if t.Spec.Selector == nil {
return "", fmt.Errorf("the service has no pod selector set")
}
return kubectl.MakeLabels(t.Spec.Selector), nil
default:
kind, err := meta.NewAccessor().Kind(object)
if err != nil {
return "", err
}
if svc.Spec.Selector == nil {
return "", fmt.Errorf("the service has no pod selector set")
}
return kubectl.MakeLabels(svc.Spec.Selector), nil
default:
return "", fmt.Errorf("it is not possible to get a pod selector from %s", mapping.Kind)
return "", fmt.Errorf("it is not possible to get a pod selector from %s", kind)
}
},
PortsForResource: func(mapping *meta.RESTMapping, namespace, name string) ([]string, error) {
PortsForObject: func(object runtime.Object) ([]string, error) {
// TODO: replace with a swagger schema based approach (identify pod selector via schema introspection)
client, err := clients.ClientForVersion("")
if err != nil {
return nil, err
}
switch mapping.Kind {
case "ReplicationController":
rc, err := client.ReplicationControllers(namespace).Get(name)
if err != nil {
return nil, err
}
return getPorts(rc.Spec.Template.Spec), nil
case "Pod":
pod, err := client.Pods(namespace).Get(name)
if err != nil {
return nil, err
}
return getPorts(pod.Spec), nil
switch t := object.(type) {
case *api.ReplicationController:
return getPorts(t.Spec.Template.Spec), nil
case *api.Pod:
return getPorts(t.Spec), nil
default:
return nil, fmt.Errorf("it is not possible to get ports from %s", mapping.Kind)
kind, err := meta.NewAccessor().Kind(object)
if err != nil {
return nil, err
}
return nil, fmt.Errorf("it is not possible to get ports from %s", kind)
}
},
Resizer: func(mapping *meta.RESTMapping) (kubectl.Resizer, error) {

View File

@ -17,8 +17,10 @@ limitations under the License.
package util
import (
"sort"
"testing"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client/clientcmd"
clientcmdapi "github.com/GoogleCloudPlatform/kubernetes/pkg/client/clientcmd/api"
)
@ -39,3 +41,62 @@ func TestNewFactoryNoFlagBindings(t *testing.T) {
t.Errorf("Expected zero flags, but got %v", factory.flags)
}
}
func TestPodSelectorForObject(t *testing.T) {
f := NewFactory(nil)
svc := &api.Service{
ObjectMeta: api.ObjectMeta{Name: "baz", Namespace: "test"},
Spec: api.ServiceSpec{
Selector: map[string]string{
"foo": "bar",
},
},
}
expected := "foo=bar"
got, err := f.PodSelectorForObject(svc)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if expected != got {
t.Fatalf("Selector mismatch! Expected %s, got %s", expected, got)
}
}
func TestPortsForObject(t *testing.T) {
f := NewFactory(nil)
pod := &api.Pod{
ObjectMeta: api.ObjectMeta{Name: "baz", Namespace: "test", ResourceVersion: "12"},
Spec: api.PodSpec{
Containers: []api.Container{
{
Ports: []api.ContainerPort{
{
ContainerPort: 101,
},
},
},
},
},
}
expected := []string{"101"}
got, err := f.PortsForObject(pod)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if len(expected) != len(got) {
t.Fatalf("Ports size mismatch! Expected %d, got %d", len(expected), len(got))
}
sort.Strings(expected)
sort.Strings(got)
for i, port := range got {
if port != expected[i] {
t.Fatalf("Port mismatch! Expected %s, got %s", expected[i], port)
}
}
}