Client should validate the incoming host value

Convert host:port and URLs passed to client.New() into the proper
values, and return an error if the value is invalid.  Change CLI
to return an error if -master is invalid.  Remove Client.rawRequest
which was not in use, and fix the involved tests. Add NewOrDie

Preserves the behavior of the client to not auth when a non-https
URL is passed (although in the future this should be corrected).
This commit is contained in:
Clayton Coleman
2014-08-28 09:56:38 -04:00
parent fa17697194
commit 818f357128
15 changed files with 203 additions and 172 deletions

View File

@@ -20,7 +20,6 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
@@ -30,7 +29,6 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/version"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
"github.com/golang/glog"
)
// Interface holds the methods for clients of Kubernetes,
@@ -88,6 +86,28 @@ type Client struct {
*RESTClient
}
// New creates a Kubernetes client. This client works with pods, replication controllers
// and services. It allows operations such as list, get, update and delete on these objects.
// host must be a host string, a host:port combo, or an http or https URL. Passing a prefix
// to a URL will prepend the server path. Returns an error if host cannot be converted to a
// valid URL.
func New(host string, auth *AuthInfo) (*Client, error) {
restClient, err := NewRESTClient(host, auth, "/api/v1beta1/")
if err != nil {
return nil, err
}
return &Client{restClient}, nil
}
// NewOrDie creates a Kubernetes client and panics if the provided host is invalid.
func NewOrDie(host string, auth *AuthInfo) *Client {
client, err := New(host, auth)
if err != nil {
panic(err)
}
return client
}
// StatusErr might get returned from an api call if your request is still being processed
// and hence the expected return data is not available yet.
type StatusErr struct {
@@ -109,20 +129,31 @@ type AuthInfo struct {
// Host is the http://... base for the URL
type RESTClient struct {
host string
prefix string
secure bool
auth *AuthInfo
httpClient *http.Client
Sync bool
PollPeriod time.Duration
Timeout time.Duration
Prefix string
}
// NewRESTClient creates a new RESTClient. This client performs generic REST functions
// such as Get, Put, Post, and Delete on specified paths.
func NewRESTClient(host string, auth *AuthInfo, prefix string) *RESTClient {
func NewRESTClient(host string, auth *AuthInfo, path string) (*RESTClient, error) {
prefix, err := normalizePrefix(host, path)
if err != nil {
return nil, err
}
base := *prefix
base.Path = ""
base.RawQuery = ""
base.Fragment = ""
return &RESTClient{
auth: auth,
host: host,
host: base.String(),
prefix: prefix.Path,
secure: prefix.Scheme == "https",
auth: auth,
httpClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
@@ -133,15 +164,36 @@ func NewRESTClient(host string, auth *AuthInfo, prefix string) *RESTClient {
Sync: false,
PollPeriod: time.Second * 2,
Timeout: time.Second * 20,
Prefix: prefix,
}
}, nil
}
// New creates a Kubernetes client. This client works with pods, replication controllers
// and services. It allows operations such as list, get, update and delete on these objects.
func New(host string, auth *AuthInfo) *Client {
return &Client{NewRESTClient(host, auth, "/api/v1beta1/")}
// normalizePrefix ensures the passed initial value is valid
func normalizePrefix(host, prefix string) (*url.URL, error) {
if host == "" {
return nil, fmt.Errorf("host must be a URL or a host:port pair")
}
base := host
hostURL, err := url.Parse(base)
if err != nil {
return nil, err
}
if hostURL.Scheme == "" {
hostURL, err = url.Parse("http://" + base)
if err != nil {
return nil, err
}
if hostURL.Path != "" && hostURL.Path != "/" {
return nil, fmt.Errorf("host must be a URL or a host:port pair: %s", base)
}
}
hostURL.Path += prefix
return hostURL, nil
}
// Secure returns true if the client is configured for secure connections.
func (c *RESTClient) Secure() bool {
return c.secure
}
// Execute a request, adds authentication (if auth != nil), and HTTPS cert ignoring.
@@ -186,55 +238,6 @@ func (c *RESTClient) doRequest(request *http.Request) ([]byte, error) {
return body, err
}
// Underlying base implementation of performing a request.
// method is the HTTP method (e.g. "GET")
// path is the path on the host to hit
// requestBody is the body of the request. Can be nil.
// target the interface to marshal the JSON response into. Can be nil.
func (c *RESTClient) rawRequest(method, path string, requestBody io.Reader, target interface{}) ([]byte, error) {
reqUrl, err := c.makeURL(path)
if err != nil {
return nil, err
}
request, err := http.NewRequest(method, reqUrl, requestBody)
if err != nil {
return nil, err
}
body, err := c.doRequest(request)
if err != nil {
return body, err
}
if target != nil {
err = api.DecodeInto(body, target)
}
if err != nil {
glog.Infof("Failed to parse: %s\n", string(body))
// FIXME: no need to return err here?
}
return body, err
}
func (c *RESTClient) makeURL(path string) (string, error) {
base := c.host
hostURL, err := url.Parse(base)
if err != nil {
return "", err
}
if hostURL.Scheme == "" {
hostURL, err = url.Parse("http://" + base)
if err != nil {
return "", err
}
if hostURL.Path != "" && hostURL.Path != "/" {
return "", fmt.Errorf("host must be a URL or a host:port pair: %s", base)
}
}
hostURL.Path += c.Prefix + path
return hostURL.String(), nil
}
// ListPods takes a selector, and returns the list of pods that match that selector
func (c *Client) ListPods(selector labels.Selector) (result api.PodList, err error) {
err = c.Get().Path("pods").SelectorParam("labels", selector).Do().Into(&result)

View File

@@ -21,6 +21,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"path"
"reflect"
"testing"
@@ -31,24 +32,22 @@ import (
)
// TODO: Move this to a common place, it's needed in multiple tests.
var apiPath = "/api/v1beta1"
func makeURL(suffix string) string {
return apiPath + suffix
}
const apiPath = "/api/v1beta1"
func TestValidatesHostParameter(t *testing.T) {
testCases := map[string]struct {
Value string
Err bool
Host string
Prefix string
Err bool
}{
"foo.bar.com": {"http://foo.bar.com/api/v1beta1/", false},
"http://host/server": {"http://host/server/api/v1beta1/", false},
"host/server": {"", true},
"127.0.0.1": {"http://127.0.0.1", "/api/v1beta1/", false},
"127.0.0.1:8080": {"http://127.0.0.1:8080", "/api/v1beta1/", false},
"foo.bar.com": {"http://foo.bar.com", "/api/v1beta1/", false},
"http://host/server": {"http://host", "/server/api/v1beta1/", false},
"host/server": {"", "", true},
}
for k, expected := range testCases {
c := RESTClient{host: k, Prefix: "/api/v1beta1/"}
actual, err := c.makeURL("")
c, err := NewRESTClient(k, nil, "/api/v1beta1/")
switch {
case err == nil && expected.Err:
t.Errorf("expected error but was nil")
@@ -56,9 +55,15 @@ func TestValidatesHostParameter(t *testing.T) {
case err != nil && !expected.Err:
t.Errorf("unexpected error %v", err)
continue
case err != nil:
continue
}
if expected.Value != actual {
t.Errorf("%s: expected %s, got %s", k, expected.Value, actual)
if e, a := expected.Host, c.host; e != a {
t.Errorf("%s: expected host %s, got %s", k, e, a)
continue
}
if e, a := expected.Prefix, c.prefix; e != a {
t.Errorf("%s: expected prefix %s, got %s", k, e, a)
continue
}
}
@@ -349,9 +354,10 @@ func (c *testClient) Setup() *testClient {
}
c.server = httptest.NewServer(c.handler)
if c.Client == nil {
c.Client = New("", nil)
c.Client = NewOrDie("localhost", nil)
}
c.Client.host = c.server.URL
c.Client.prefix = "/api/v1beta1/"
c.QueryValidator = map[string]func(string, string) bool{}
return c
}
@@ -374,7 +380,7 @@ func (c *testClient) Validate(t *testing.T, received interface{}, err error) {
// We check the query manually, so blank it out so that FakeHandler.ValidateRequest
// won't check it.
c.handler.RequestReceived.URL.RawQuery = ""
c.handler.ValidateRequest(t, makeURL(c.Request.Path), c.Request.Method, requestBody)
c.handler.ValidateRequest(t, path.Join(apiPath, c.Request.Path), c.Request.Method, requestBody)
for key, values := range c.Request.Query {
validator, ok := c.QueryValidator[key]
if !ok {
@@ -437,18 +443,27 @@ func TestDeleteService(t *testing.T) {
c.Validate(t, nil, err)
}
func TestMakeRequest(t *testing.T) {
func TestDoRequest(t *testing.T) {
invalid := "aaaaa"
testClients := []testClient{
{Request: testRequest{Method: "GET", Path: "/good"}, Response: Response{StatusCode: 200}},
{Request: testRequest{Method: "GET", Path: "/bad%ZZ"}, Error: true},
{Client: New("", &AuthInfo{"foo", "bar"}), Request: testRequest{Method: "GET", Path: "/auth", Header: "Authorization"}, Response: Response{StatusCode: 200}},
{Client: &Client{&RESTClient{httpClient: http.DefaultClient}}, Request: testRequest{Method: "GET", Path: "/nocertificate"}, Error: true},
{Request: testRequest{Method: "GET", Path: "/error"}, Response: Response{StatusCode: 500}, Error: true},
{Request: testRequest{Method: "POST", Path: "/faildecode"}, Response: Response{StatusCode: 200, Body: "aaaaa"}, Target: &struct{}{}, Error: true},
{Request: testRequest{Method: "GET", Path: "/failread"}, Response: Response{StatusCode: 200, Body: "aaaaa"}, Target: &struct{}{}, Error: true},
{Request: testRequest{Method: "GET", Path: "good"}, Response: Response{StatusCode: 200}},
{Request: testRequest{Method: "GET", Path: "bad%ZZ"}, Error: true},
{Client: NewOrDie("localhost", &AuthInfo{"foo", "bar"}), Request: testRequest{Method: "GET", Path: "auth", Header: "Authorization"}, Response: Response{StatusCode: 200}},
{Client: &Client{&RESTClient{httpClient: http.DefaultClient}}, Request: testRequest{Method: "GET", Path: "nocertificate"}, Error: true},
{Request: testRequest{Method: "GET", Path: "error"}, Response: Response{StatusCode: 500}, Error: true},
{Request: testRequest{Method: "POST", Path: "faildecode"}, Response: Response{StatusCode: 200, RawBody: &invalid}, Target: &struct{}{}},
{Request: testRequest{Method: "GET", Path: "failread"}, Response: Response{StatusCode: 200, RawBody: &invalid}, Target: &struct{}{}},
}
for _, c := range testClients {
response, err := c.Setup().rawRequest(c.Request.Method, c.Request.Path[1:], nil, c.Target)
client := c.Setup()
prefix, _ := url.Parse(client.host)
prefix.Path = client.prefix + c.Request.Path
request := &http.Request{
Method: c.Request.Method,
Header: make(http.Header),
URL: prefix,
}
response, err := client.doRequest(request)
c.Validate(t, response, err)
}
}
@@ -464,7 +479,10 @@ func TestDoRequestAccepted(t *testing.T) {
testServer := httptest.NewServer(&fakeHandler)
request, _ := http.NewRequest("GET", testServer.URL+"/foo/bar", nil)
auth := AuthInfo{User: "user", Password: "pass"}
c := New(testServer.URL, &auth)
c, err := New(testServer.URL, &auth)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := c.doRequest(request)
if request.Header["Authorization"] == nil {
t.Errorf("Request is missing authorization header: %#v", *request)
@@ -498,7 +516,10 @@ func TestDoRequestAcceptedSuccess(t *testing.T) {
testServer := httptest.NewServer(&fakeHandler)
request, _ := http.NewRequest("GET", testServer.URL+"/foo/bar", nil)
auth := AuthInfo{User: "user", Password: "pass"}
c := New(testServer.URL, &auth)
c, err := New(testServer.URL, &auth)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
body, err := c.doRequest(request)
if request.Header["Authorization"] == nil {
t.Errorf("Request is missing authorization header: %#v", *request)
@@ -532,7 +553,7 @@ func TestGetServerVersion(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write(output)
}))
client := New(server.URL, nil)
client := NewOrDie(server.URL, nil)
got, err := client.ServerVersion()
if err != nil {

View File

@@ -56,7 +56,7 @@ func (c *RESTClient) Verb(verb string) *Request {
return &Request{
verb: verb,
c: c,
path: c.Prefix,
path: c.prefix,
sync: c.Sync,
timeout: c.Timeout,
params: map[string]string{},

View File

@@ -45,8 +45,8 @@ func TestDoRequestNewWay(t *testing.T) {
}
testServer := httptest.NewServer(&fakeHandler)
auth := AuthInfo{User: "user", Password: "pass"}
s := New(testServer.URL, &auth)
obj, err := s.Verb("POST").
c := NewOrDie(testServer.URL, &auth)
obj, err := c.Verb("POST").
Path("foo/bar").
Path("baz").
ParseSelectorParam("labels", "name=foo").
@@ -80,8 +80,8 @@ func TestDoRequestNewWayReader(t *testing.T) {
}
testServer := httptest.NewServer(&fakeHandler)
auth := AuthInfo{User: "user", Password: "pass"}
s := New(testServer.URL, &auth)
obj, err := s.Verb("POST").
c := NewOrDie(testServer.URL, &auth)
obj, err := c.Verb("POST").
Path("foo/bar").
Path("baz").
SelectorParam("labels", labels.Set{"name": "foo"}.AsSelector()).
@@ -117,8 +117,8 @@ func TestDoRequestNewWayObj(t *testing.T) {
}
testServer := httptest.NewServer(&fakeHandler)
auth := AuthInfo{User: "user", Password: "pass"}
s := New(testServer.URL, &auth)
obj, err := s.Verb("POST").
c := NewOrDie(testServer.URL, &auth)
obj, err := c.Verb("POST").
Path("foo/bar").
Path("baz").
SelectorParam("labels", labels.Set{"name": "foo"}.AsSelector()).
@@ -167,8 +167,8 @@ func TestDoRequestNewWayFile(t *testing.T) {
}
testServer := httptest.NewServer(&fakeHandler)
auth := AuthInfo{User: "user", Password: "pass"}
s := New(testServer.URL, &auth)
obj, err := s.Verb("POST").
c := NewOrDie(testServer.URL, &auth)
obj, err := c.Verb("POST").
Path("foo/bar").
Path("baz").
ParseSelectorParam("labels", "name=foo").
@@ -192,7 +192,7 @@ func TestDoRequestNewWayFile(t *testing.T) {
}
func TestVerbs(t *testing.T) {
c := New("", nil)
c := NewOrDie("localhost", nil)
if r := c.Post(); r.verb != "POST" {
t.Errorf("Post verb is wrong")
}
@@ -209,7 +209,7 @@ func TestVerbs(t *testing.T) {
func TestAbsPath(t *testing.T) {
expectedPath := "/bar/foo"
c := New("", nil)
c := NewOrDie("localhost", nil)
r := c.Post().Path("/foo").AbsPath(expectedPath)
if r.path != expectedPath {
t.Errorf("unexpected path: %s, expected %s", r.path, expectedPath)
@@ -217,7 +217,7 @@ func TestAbsPath(t *testing.T) {
}
func TestSync(t *testing.T) {
c := New("", nil)
c := NewOrDie("localhost", nil)
r := c.Get()
if r.sync {
t.Errorf("sync has wrong default")
@@ -238,13 +238,13 @@ func TestUintParam(t *testing.T) {
testVal uint64
expectStr string
}{
{"foo", 31415, "?foo=31415"},
{"bar", 42, "?bar=42"},
{"baz", 0, "?baz=0"},
{"foo", 31415, "http://localhost?foo=31415"},
{"bar", 42, "http://localhost?bar=42"},
{"baz", 0, "http://localhost?baz=0"},
}
for _, item := range table {
c := New("", nil)
c := NewOrDie("localhost", nil)
r := c.Get().AbsPath("").UintParam(item.name, item.testVal)
if e, a := item.expectStr, r.finalURL(); e != a {
t.Errorf("expected %v, got %v", e, a)
@@ -263,7 +263,7 @@ func TestUnacceptableParamNames(t *testing.T) {
}
for _, item := range table {
c := New("", nil)
c := NewOrDie("localhost", nil)
r := c.Get().setParam(item.name, item.testVal)
if e, a := item.expectSuccess, r.err == nil; e != a {
t.Errorf("expected %v, got %v (%v)", e, a, r.err)
@@ -272,7 +272,7 @@ func TestUnacceptableParamNames(t *testing.T) {
}
func TestSetPollPeriod(t *testing.T) {
c := New("", nil)
c := NewOrDie("localhost", nil)
r := c.Get()
if r.pollPeriod == 0 {
t.Errorf("polling should be on by default")
@@ -303,12 +303,12 @@ func TestPolling(t *testing.T) {
}))
auth := AuthInfo{User: "user", Password: "pass"}
s := New(testServer.URL, &auth)
c := NewOrDie(testServer.URL, &auth)
trials := []func(){
func() {
// Check that we do indeed poll when asked to.
obj, err := s.Get().PollPeriod(5 * time.Millisecond).Do().Get()
obj, err := c.Get().PollPeriod(5 * time.Millisecond).Do().Get()
if err != nil {
t.Errorf("Unexpected error: %v %#v", err, err)
return
@@ -323,7 +323,7 @@ func TestPolling(t *testing.T) {
},
func() {
// Check that we don't poll when asked not to.
obj, err := s.Get().PollPeriod(0).Do().Get()
obj, err := c.Get().PollPeriod(0).Do().Get()
if err == nil {
t.Errorf("Unexpected non error: %v", obj)
return
@@ -405,7 +405,10 @@ func TestWatch(t *testing.T) {
}
}))
s := New(testServer.URL, &auth)
s, err := New(testServer.URL, &auth)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
watching, err := s.Get().Path("path/to/watch/thing").Watch()
if err != nil {