diff --git a/pkg/accesscontrol/access_control.go b/pkg/accesscontrol/access_control.go index 9e29cdbe..7c19fbfe 100644 --- a/pkg/accesscontrol/access_control.go +++ b/pkg/accesscontrol/access_control.go @@ -17,8 +17,8 @@ func NewAccessControl() *AccessControl { func (a *AccessControl) CanWatch(apiOp *types.APIRequest, schema *types.APISchema) error { access := GetAccessListMap(schema) - if !access.Grants("watch", "*", "*") { - return fmt.Errorf("watch not allowed") + if _, ok := access["watch"]; ok { + return nil } - return nil + return fmt.Errorf("watch not allowed") } diff --git a/pkg/accesscontrol/access_set.go b/pkg/accesscontrol/access_set.go index 4e086e34..516dce3f 100644 --- a/pkg/accesscontrol/access_set.go +++ b/pkg/accesscontrol/access_set.go @@ -1,9 +1,12 @@ package accesscontrol import ( + "sort" + "github.com/rancher/steve/pkg/attributes" "github.com/rancher/steve/pkg/schemaserver/types" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/sets" ) type AccessSet struct { @@ -18,6 +21,26 @@ type key struct { gr schema.GroupResource } +func (a *AccessSet) Namespaces() (result []string) { + set := map[string]bool{} + for k, as := range a.set { + if k.verb != "get" && k.verb != "list" { + continue + } + for access := range as { + if access.Namespace == all { + continue + } + set[access.Namespace] = true + } + } + for k := range set { + result = append(result, k) + } + sort.Strings(result) + return +} + func (a *AccessSet) Merge(right *AccessSet) { for k, accessMap := range right.set { m, ok := a.set[k] @@ -36,6 +59,7 @@ func (a *AccessSet) Merge(right *AccessSet) { } func (a AccessSet) AccessListFor(verb string, gr schema.GroupResource) (result AccessList) { + dedup := map[Access]bool{} for _, v := range []string{all, verb} { for _, g := range []string{all, gr.Group} { for _, r := range []string{all, gr.Resource} { @@ -46,12 +70,16 @@ func (a AccessSet) AccessListFor(verb string, gr schema.GroupResource) (result A Resource: r, }, }] { - result = append(result, k) + dedup[k] = true } } } } + for k := range dedup { + result = append(result, k) + } + return } @@ -76,6 +104,41 @@ func (a AccessListByVerb) Grants(verb, namespace, name string) bool { return a[verb].Grants(namespace, name) } +func (a AccessListByVerb) All(verb string) bool { + return a.Grants(verb, all, all) +} + +type Resources struct { + All bool + Names sets.String +} + +func (a AccessListByVerb) Granted(verb string) (result map[string]Resources) { + result = map[string]Resources{} + + // if list, we need to check get also + verbs := []string{verb} + if verb == "list" { + verbs = append(verbs, "get") + } + + for _, verb := range verbs { + for _, access := range a[verb] { + resources := result[access.Namespace] + if access.ResourceName == all { + resources.All = true + } else { + if resources.Names == nil { + resources.Names = sets.String{} + } + resources.Names.Insert(access.ResourceName) + } + result[access.Namespace] = resources + } + } + return result +} + func (a AccessListByVerb) AnyVerb(verb ...string) bool { for _, v := range verb { if len(a[v]) > 0 { diff --git a/pkg/client/factory.go b/pkg/client/factory.go index 95c9aab3..8d393c17 100644 --- a/pkg/client/factory.go +++ b/pkg/client/factory.go @@ -72,15 +72,23 @@ func (p *Factory) MetadataClient() metadata.Interface { } func (p *Factory) Client(ctx *types.APIRequest, s *types.APISchema, namespace string) (dynamic.ResourceInterface, error) { - return p.newClient(ctx, p.clientCfg, s, namespace) + return newClient(ctx, p.clientCfg, s, namespace, p.impersonate) +} + +func (p *Factory) AdminClient(ctx *types.APIRequest, s *types.APISchema, namespace string) (dynamic.ResourceInterface, error) { + return newClient(ctx, p.clientCfg, s, namespace, false) } func (p *Factory) ClientForWatch(ctx *types.APIRequest, s *types.APISchema, namespace string) (dynamic.ResourceInterface, error) { - return p.newClient(ctx, p.watchClientCfg, s, namespace) + return newClient(ctx, p.watchClientCfg, s, namespace, p.impersonate) } -func (p *Factory) newClient(ctx *types.APIRequest, cfg *rest.Config, s *types.APISchema, namespace string) (dynamic.ResourceInterface, error) { - if p.impersonate { +func (p *Factory) AdminClientForWatch(ctx *types.APIRequest, s *types.APISchema, namespace string) (dynamic.ResourceInterface, error) { + return newClient(ctx, p.watchClientCfg, s, namespace, false) +} + +func newClient(ctx *types.APIRequest, cfg *rest.Config, s *types.APISchema, namespace string, impersonate bool) (dynamic.ResourceInterface, error) { + if impersonate { user, ok := request.UserFrom(ctx.Context()) if !ok { return nil, fmt.Errorf("user not found for impersonation") diff --git a/pkg/schema/factory.go b/pkg/schema/factory.go index 6c71fe05..1b3876cc 100644 --- a/pkg/schema/factory.go +++ b/pkg/schema/factory.go @@ -72,7 +72,19 @@ func (c *Collection) schemasForSubject(access *accesscontrol.AccessSet) (*types. } if len(verbAccess) == 0 { - continue + if gr.Group == "" && gr.Resource == "namespaces" { + var accessList accesscontrol.AccessList + for _, ns := range access.Namespaces() { + accessList = append(accessList, accesscontrol.Access{ + Namespace: "*", + ResourceName: ns, + }) + } + verbAccess["list"] = accessList + verbAccess["watch"] = accessList + } else { + continue + } } s = s.DeepCopy() diff --git a/pkg/schemaserver/server/server.go b/pkg/schemaserver/server/server.go index 3426a4f7..2b327176 100644 --- a/pkg/schemaserver/server/server.go +++ b/pkg/schemaserver/server/server.go @@ -165,6 +165,8 @@ func (s *Server) handle(apiOp *types.APIRequest, parser parse.Parser) { apiOp.WriteResponse(code, obj) } else if list, ok := data.(types.APIObjectList); ok { apiOp.WriteResponseList(code, list) + } else if code > http.StatusOK { + apiOp.Response.WriteHeader(code) } } diff --git a/pkg/schemaserver/types/server_types.go b/pkg/schemaserver/types/server_types.go index fe489a68..975177cd 100644 --- a/pkg/schemaserver/types/server_types.go +++ b/pkg/schemaserver/types/server_types.go @@ -26,6 +26,13 @@ type RawResource struct { APIObject APIObject `json:"-" yaml:"-"` } +type Pagination struct { + Limit int `json:"limit,omitempty"` + First string `json:"first,omitempty"` + Next string `json:"next,omitempty"` + Partial bool `json:"partial,omitempty"` +} + func (r *RawResource) MarshalJSON() ([]byte, error) { type r_ RawResource outer, err := json.Marshal((*r_)(r)) @@ -160,6 +167,7 @@ type URLBuilder interface { ResourceLink(schema *APISchema, id string) string Link(schema *APISchema, id string, linkName string) string Action(schema *APISchema, id string, action string) string + Marker(marker string) string RelativeToRoot(path string) string } diff --git a/pkg/schemaserver/types/types.go b/pkg/schemaserver/types/types.go index f46bfb91..35492684 100644 --- a/pkg/schemaserver/types/types.go +++ b/pkg/schemaserver/types/types.go @@ -6,16 +6,13 @@ import ( "github.com/rancher/wrangler/pkg/schemas" ) -const ( - ResourceFieldID = "id" -) - type Collection struct { Type string `json:"type,omitempty"` Links map[string]string `json:"links"` CreateTypes map[string]string `json:"createTypes,omitempty"` Actions map[string]string `json:"actions"` ResourceType string `json:"resourceType"` + Pagination *Pagination `json:"pagination,omitempty"` Revision string `json:"revision,omitempty"` Continue string `json:"continue,omitempty"` } diff --git a/pkg/schemaserver/urlbuilder/url.go b/pkg/schemaserver/urlbuilder/url.go index 86a810b6..880e2bcd 100644 --- a/pkg/schemaserver/urlbuilder/url.go +++ b/pkg/schemaserver/urlbuilder/url.go @@ -62,6 +62,15 @@ type DefaultURLBuilder struct { query url.Values } +func (u *DefaultURLBuilder) Marker(marker string) string { + newValues := url.Values{} + for k, v := range u.query { + newValues[k] = v + } + newValues.Set("continue", marker) + return u.Current() + "?" + newValues.Encode() +} + func (u *DefaultURLBuilder) Link(schema *types.APISchema, id string, linkName string) string { return u.schemaURL(schema, id, linkName) } diff --git a/pkg/schemaserver/writer/encoding.go b/pkg/schemaserver/writer/encoding.go index 812978c1..e8d91ce7 100644 --- a/pkg/schemaserver/writer/encoding.go +++ b/pkg/schemaserver/writer/encoding.go @@ -3,6 +3,7 @@ package writer import ( "io" "net/http" + "strconv" "github.com/rancher/steve/pkg/schemaserver/types" ) @@ -99,6 +100,14 @@ func (j *EncodingResponseWriter) addLinks(schema *types.APISchema, context *type } } +func getLimit(req *http.Request) int { + limit, err := strconv.Atoi(req.Header.Get("limit")) + if err == nil && limit > 0 { + return limit + } + return 0 +} + func newCollection(apiOp *types.APIRequest, list types.APIObjectList) *types.GenericCollection { result := &types.GenericCollection{ Collection: types.Collection{ @@ -114,6 +123,18 @@ func newCollection(apiOp *types.APIRequest, list types.APIObjectList) *types.Gen }, } + partial := list.Continue != "" || apiOp.Query.Get("continue") != "" + if partial { + result.Pagination = &types.Pagination{ + Limit: getLimit(apiOp.Request), + First: apiOp.URLBuilder.Current(), + Partial: true, + } + if list.Continue != "" { + result.Pagination.Next = apiOp.URLBuilder.Marker(list.Continue) + } + } + if apiOp.Method == http.MethodGet { if apiOp.AccessControl.CanCreate(apiOp, apiOp.Schema) == nil { result.CreateTypes[apiOp.Schema.ID] = apiOp.URLBuilder.Collection(apiOp.Schema) diff --git a/pkg/server/handler/apiserver.go b/pkg/server/handler/apiserver.go index f4cb3577..a6258b8a 100644 --- a/pkg/server/handler/apiserver.go +++ b/pkg/server/handler/apiserver.go @@ -12,7 +12,7 @@ import ( "github.com/rancher/steve/pkg/schemaserver/urlbuilder" "github.com/rancher/steve/pkg/server/router" "github.com/sirupsen/logrus" - "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/client-go/rest" ) @@ -57,9 +57,9 @@ type apiServer struct { } func (a *apiServer) common(rw http.ResponseWriter, req *http.Request) (*types.APIRequest, bool) { - user := &user.DefaultInfo{ - Name: "admin", - Groups: []string{"system:masters"}, + user, ok := request.UserFrom(req.Context()) + if !ok { + return nil, false } schemas, err := a.sf.Schemas(user) diff --git a/pkg/server/resources/common/formatter.go b/pkg/server/resources/common/formatter.go index ebef5f73..72845c5b 100644 --- a/pkg/server/resources/common/formatter.go +++ b/pkg/server/resources/common/formatter.go @@ -1,15 +1,16 @@ package common import ( + "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/schema" "github.com/rancher/steve/pkg/schemaserver/types" "github.com/rancher/steve/pkg/server/store/proxy" "k8s.io/apimachinery/pkg/api/meta" ) -func DefaultTemplate(clientGetter proxy.ClientGetter) schema.Template { +func DefaultTemplate(clientGetter proxy.ClientGetter, asl accesscontrol.AccessSetLookup) schema.Template { return schema.Template{ - Store: proxy.NewProxyStore(clientGetter), + Store: proxy.NewProxyStore(clientGetter, asl), Formatter: Formatter, } } diff --git a/pkg/server/resources/counts/counts.go b/pkg/server/resources/counts/counts.go index 7f379fcd..e3f06d6a 100644 --- a/pkg/server/resources/counts/counts.go +++ b/pkg/server/resources/counts/counts.go @@ -232,18 +232,25 @@ func (s *Store) getCount(apiOp *types.APIRequest) Count { for _, schema := range s.schemasToWatch(apiOp) { gvr := attributes.GVR(schema) + access, _ := attributes.Access(schema).(accesscontrol.AccessListByVerb) rev := 0 itemCount := ItemCount{ Namespaces: map[string]int{}, } + all := access.Grants("list", "*", "*") + for _, obj := range s.ccache.List(gvr) { - _, ns, revision, ok := getInfo(obj) + name, ns, revision, ok := getInfo(obj) if !ok { continue } + if !all && !access.Grants("list", ns, name) && !access.Grants("get", ns, name) { + continue + } + if revision > rev { rev = revision } diff --git a/pkg/server/resources/schema.go b/pkg/server/resources/schema.go index 3d8cb354..c9665f77 100644 --- a/pkg/server/resources/schema.go +++ b/pkg/server/resources/schema.go @@ -1,6 +1,7 @@ package resources import ( + "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/client" "github.com/rancher/steve/pkg/clustercache" "github.com/rancher/steve/pkg/schema" @@ -21,8 +22,8 @@ func DefaultSchemas(baseSchema *types.APISchemas, discovery discovery.DiscoveryI return baseSchema } -func DefaultSchemaTemplates(cf *client.Factory) []schema.Template { +func DefaultSchemaTemplates(cf *client.Factory, lookup accesscontrol.AccessSetLookup) []schema.Template { return []schema.Template{ - common.DefaultTemplate(cf), + common.DefaultTemplate(cf, lookup), } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 64a47e01..071b59d9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -53,16 +53,16 @@ func setup(ctx context.Context, server *Server) (http.Handler, *schema.Collectio return nil, nil, err } - ccache := clustercache.NewClusterCache(ctx, cf.MetadataClient()) - - server.BaseSchemas = resources.DefaultSchemas(server.BaseSchemas, server.K8s.Discovery(), ccache) - server.SchemaTemplates = append(server.SchemaTemplates, resources.DefaultSchemaTemplates(cf)...) - asl := server.AccessSetLookup if asl == nil { asl = accesscontrol.NewAccessStore(ctx, true, server.RBAC) } + ccache := clustercache.NewClusterCache(ctx, cf.MetadataClient()) + + server.BaseSchemas = resources.DefaultSchemas(server.BaseSchemas, server.K8s.Discovery(), ccache) + server.SchemaTemplates = append(server.SchemaTemplates, resources.DefaultSchemaTemplates(cf, asl)...) + cols, err := common.NewDynamicColumns(server.RestConfig) if err != nil { return nil, nil, err diff --git a/pkg/server/store/proxy/con_eg.go b/pkg/server/store/proxy/con_eg.go new file mode 100644 index 00000000..56ac4256 --- /dev/null +++ b/pkg/server/store/proxy/con_eg.go @@ -0,0 +1,204 @@ +package proxy + +import ( + "context" + "encoding/base64" + "encoding/json" + + "github.com/rancher/steve/pkg/schemaserver/types" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" +) + +type ParallelPartitionLister struct { + Lister PartitionLister + Concurrency int64 + Partitions []Partition + state *listState + revision string + err error +} + +type PartitionLister func(ctx context.Context, partition Partition, cont string, revision string, limit int) (types.APIObjectList, error) + +func (p *ParallelPartitionLister) Err() error { + return p.err +} + +func (p *ParallelPartitionLister) Revision() string { + return p.revision +} + +func (p *ParallelPartitionLister) Continue() string { + if p.state == nil { + return "" + } + bytes, err := json.Marshal(p.state) + if err != nil { + return "" + } + return base64.StdEncoding.EncodeToString(bytes) +} + +func indexOrZero(partitions []Partition, namespace string) int { + if namespace == "" { + return 0 + } + for i, partition := range partitions { + if partition.Namespace == namespace { + return i + } + } + return 0 +} + +func (p *ParallelPartitionLister) List(ctx context.Context, limit int, resume string) (<-chan []types.APIObject, error) { + var state listState + if resume != "" { + bytes, err := base64.StdEncoding.DecodeString(resume) + if err != nil { + return nil, err + } + if err := json.Unmarshal(bytes, &state); err != nil { + return nil, err + } + + if state.Limit > 0 { + limit = state.Limit + } + } + + result := make(chan []types.APIObject) + go p.feeder(ctx, state, limit, result) + return result, nil +} + +type listState struct { + Revision string `json:"r,omitempty"` + PartitionNamespace string `json:"p,omitempty"` + Continue string `json:"c,omitempty"` + Offset int `json:"o,omitempty"` + Limit int `json:"l,omitempty"` +} + +func (p *ParallelPartitionLister) feeder(ctx context.Context, state listState, limit int, result chan []types.APIObject) { + var ( + sem = semaphore.NewWeighted(p.Concurrency) + capacity = limit + last chan struct{} + ) + + eg, ctx := errgroup.WithContext(ctx) + defer func() { + err := eg.Wait() + if p.err == nil { + p.err = err + } + close(result) + }() + + for i := indexOrZero(p.Partitions, state.PartitionNamespace); i < len(p.Partitions); i++ { + if capacity <= 0 || isDone(ctx) { + break + } + + var ( + partition = p.Partitions[i] + tickets = int64(1) + turn = last + next = make(chan struct{}) + ) + + // setup a linked list of channel to control insertion order + last = next + + if state.Revision == "" { + // don't have a revision yet so grab all tickets to set a revision + tickets = 3 + } + if err := sem.Acquire(ctx, tickets); err != nil { + p.err = err + break + } + + // make state local + state := state + eg.Go(func() error { + defer sem.Release(tickets) + defer close(next) + + for { + cont := "" + if partition.Namespace == state.PartitionNamespace { + cont = state.Continue + } + list, err := p.Lister(ctx, partition, cont, state.Revision, limit) + if err != nil { + return err + } + + waitForTurn(ctx, turn) + if p.state != nil { + return nil + } + + if state.Revision == "" { + state.Revision = list.Revision + } + + if p.revision == "" { + p.revision = list.Revision + } + + if state.PartitionNamespace == partition.Namespace && state.Offset > 0 && state.Offset < len(list.Objects) { + list.Objects = list.Objects[state.Offset:] + } + + if len(list.Objects) > capacity { + result <- list.Objects[:capacity] + // save state to redo this list at this offset + p.state = &listState{ + Revision: list.Revision, + PartitionNamespace: partition.Namespace, + Continue: cont, + Offset: capacity, + Limit: limit, + } + capacity = 0 + return nil + } else { + result <- list.Objects + capacity -= len(list.Objects) + if list.Continue == "" { + return nil + } + // loop again and get more data + state.Continue = list.Continue + state.PartitionNamespace = partition.Namespace + state.Offset = 0 + } + } + }) + } + + p.err = eg.Wait() +} + +func waitForTurn(ctx context.Context, turn chan struct{}) { + if turn == nil { + return + } + select { + case <-turn: + case <-ctx.Done(): + } +} + +func isDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/pkg/server/store/proxy/proxy_store.go b/pkg/server/store/proxy/proxy_store.go index e5c71a74..b17a0e6d 100644 --- a/pkg/server/store/proxy/proxy_store.go +++ b/pkg/server/store/proxy/proxy_store.go @@ -9,6 +9,7 @@ import ( "regexp" "github.com/pkg/errors" + "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/attributes" "github.com/rancher/steve/pkg/schemaserver/types" "github.com/rancher/wrangler/pkg/data" @@ -19,6 +20,7 @@ import ( "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" apitypes "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/dynamic" ) @@ -29,17 +31,24 @@ var ( type ClientGetter interface { Client(ctx *types.APIRequest, schema *types.APISchema, namespace string) (dynamic.ResourceInterface, error) + AdminClient(ctx *types.APIRequest, schema *types.APISchema, namespace string) (dynamic.ResourceInterface, error) ClientForWatch(ctx *types.APIRequest, schema *types.APISchema, namespace string) (dynamic.ResourceInterface, error) + AdminClientForWatch(ctx *types.APIRequest, schema *types.APISchema, namespace string) (dynamic.ResourceInterface, error) } type Store struct { clientGetter ClientGetter } -func NewProxyStore(clientGetter ClientGetter) types.Store { +func NewProxyStore(clientGetter ClientGetter, lookup accesscontrol.AccessSetLookup) types.Store { return &errorStore{ - Store: &Store{ - clientGetter: clientGetter, + Store: &WatchRefresh{ + Store: &RBACStore{ + Store: &Store{ + clientGetter: clientGetter, + }, + }, + asl: lookup, }, } } @@ -177,18 +186,43 @@ func tableToObjects(obj map[string]interface{}) []unstructured.Unstructured { return result } -func (s *Store) List(apiOp *types.APIRequest, schema *types.APISchema) (types.APIObjectList, error) { - k8sClient, err := s.clientGetter.Client(apiOp, schema, apiOp.Namespace) +func (s *Store) ByNames(apiOp *types.APIRequest, schema *types.APISchema, names sets.String) (types.APIObjectList, error) { + adminClient, err := s.clientGetter.AdminClient(apiOp, schema, apiOp.Namespace) if err != nil { return types.APIObjectList{}, err } + objs, err := s.list(apiOp, schema, adminClient) + if err != nil { + return types.APIObjectList{}, err + } + + var filtered []types.APIObject + for _, obj := range objs.Objects { + if names.Has(obj.Name()) { + filtered = append(filtered, obj) + } + } + + objs.Objects = filtered + return objs, nil +} + +func (s *Store) List(apiOp *types.APIRequest, schema *types.APISchema) (types.APIObjectList, error) { + client, err := s.clientGetter.Client(apiOp, schema, apiOp.Namespace) + if err != nil { + return types.APIObjectList{}, err + } + return s.list(apiOp, schema, client) +} + +func (s *Store) list(apiOp *types.APIRequest, schema *types.APISchema, client dynamic.ResourceInterface) (types.APIObjectList, error) { opts := metav1.ListOptions{} if err := decodeParams(apiOp, &opts); err != nil { return types.APIObjectList{}, nil } - resultList, err := k8sClient.List(opts) + resultList, err := client.List(opts) if err != nil { return types.APIObjectList{}, err } @@ -255,15 +289,41 @@ func (s *Store) listAndWatch(apiOp *types.APIRequest, k8sClient dynamic.Resource } } -func (s *Store) Watch(apiOp *types.APIRequest, schema *types.APISchema, w types.WatchRequest) (chan types.APIEvent, error) { - k8sClient, err := s.clientGetter.ClientForWatch(apiOp, schema, apiOp.Namespace) +func (s *Store) WatchNames(apiOp *types.APIRequest, schema *types.APISchema, w types.WatchRequest, names sets.String) (chan types.APIEvent, error) { + adminClient, err := s.clientGetter.ClientForWatch(apiOp, schema, apiOp.Namespace) + if err != nil { + return nil, err + } + c, err := s.watch(apiOp, schema, w, adminClient) if err != nil { return nil, err } result := make(chan types.APIEvent) go func() { - s.listAndWatch(apiOp, k8sClient, schema, w, result) + defer close(result) + for item := range c { + if item.Error != nil && names.Has(item.Object.Name()) { + result <- item + } + } + }() + + return result, nil +} + +func (s *Store) Watch(apiOp *types.APIRequest, schema *types.APISchema, w types.WatchRequest) (chan types.APIEvent, error) { + client, err := s.clientGetter.ClientForWatch(apiOp, schema, apiOp.Namespace) + if err != nil { + return nil, err + } + return s.watch(apiOp, schema, w, client) +} + +func (s *Store) watch(apiOp *types.APIRequest, schema *types.APISchema, w types.WatchRequest, client dynamic.ResourceInterface) (chan types.APIEvent, error) { + result := make(chan types.APIEvent) + go func() { + s.listAndWatch(apiOp, client, schema, w, result) logrus.Debugf("closing watcher for %s", schema.ID) close(result) }() diff --git a/pkg/server/store/proxy/rbac_store.go b/pkg/server/store/proxy/rbac_store.go new file mode 100644 index 00000000..dfedf0da --- /dev/null +++ b/pkg/server/store/proxy/rbac_store.go @@ -0,0 +1,182 @@ +package proxy + +import ( + "context" + "net/http" + "sort" + "strconv" + + "github.com/rancher/steve/pkg/accesscontrol" + "github.com/rancher/steve/pkg/attributes" + "github.com/rancher/steve/pkg/schemaserver/types" + "golang.org/x/sync/errgroup" + "k8s.io/apimachinery/pkg/util/sets" +) + +type RBACStore struct { + *Store +} + +type Partition struct { + Namespace string + All bool + Names sets.String +} + +func isPassthrough(apiOp *types.APIRequest, schema *types.APISchema, verb string) ([]Partition, bool) { + accessListByVerb, _ := attributes.Access(schema).(accesscontrol.AccessListByVerb) + if accessListByVerb.All(verb) { + return nil, true + } + + resources := accessListByVerb.Granted(verb) + if apiOp.Namespace != "" { + if resources[apiOp.Namespace].All { + return nil, true + } else { + return []Partition{ + { + Namespace: apiOp.Namespace, + Names: resources[apiOp.Namespace].Names, + }, + }, false + } + } + + var result []Partition + + if attributes.Namespaced(schema) { + for k, v := range resources { + result = append(result, Partition{ + Namespace: k, + All: v.All, + Names: v.Names, + }) + } + } else { + for _, v := range resources { + result = append(result, Partition{ + All: v.All, + Names: v.Names, + }) + } + } + + return result, false +} + +func (r *RBACStore) List(apiOp *types.APIRequest, schema *types.APISchema) (types.APIObjectList, error) { + partitions, passthrough := isPassthrough(apiOp, schema, "list") + if passthrough { + return r.Store.List(apiOp, schema) + } + + resume := apiOp.Request.URL.Query().Get("continue") + limit := getLimit(apiOp.Request) + + sort.Slice(partitions, func(i, j int) bool { + return partitions[i].Namespace < partitions[j].Namespace + }) + + lister := &ParallelPartitionLister{ + Lister: func(ctx context.Context, partition Partition, cont string, revision string, limit int) (types.APIObjectList, error) { + return r.list(apiOp, schema, partition, cont, revision, limit) + }, + Concurrency: 3, + Partitions: partitions, + } + + result := types.APIObjectList{} + items, err := lister.List(apiOp.Context(), limit, resume) + if err != nil { + return result, err + } + + for item := range items { + result.Objects = append(result.Objects, item...) + } + + result.Continue = lister.Continue() + result.Revision = lister.Revision() + return result, lister.Err() +} + +func getLimit(req *http.Request) int { + limitString := req.URL.Query().Get("limit") + limit, err := strconv.Atoi(limitString) + if err != nil { + limit = 0 + } + if limit <= 0 { + limit = 100000 + } + return limit +} + +func (r *RBACStore) list(apiOp *types.APIRequest, schema *types.APISchema, partition Partition, cont, revision string, limit int) (types.APIObjectList, error) { + req := *apiOp + req.Namespace = partition.Namespace + req.Request = req.Request.Clone(apiOp.Context()) + + values := req.Request.URL.Query() + values.Set("continue", cont) + values.Set("revision", revision) + if limit > 0 { + values.Set("limit", strconv.Itoa(limit)) + } else { + values.Del("limit") + } + req.Request.URL.RawQuery = values.Encode() + + if partition.All { + return r.Store.List(&req, schema) + } + return r.Store.ByNames(&req, schema, partition.Names) +} + +func (r *RBACStore) Watch(apiOp *types.APIRequest, schema *types.APISchema, w types.WatchRequest) (chan types.APIEvent, error) { + partitions, passthrough := isPassthrough(apiOp, schema, "watch") + if passthrough { + return r.Store.Watch(apiOp, schema, w) + } + + ctx, cancel := context.WithCancel(apiOp.Context()) + apiOp = apiOp.WithContext(ctx) + + eg := errgroup.Group{} + response := make(chan types.APIEvent) + for _, partition := range partitions { + partition := partition + eg.Go(func() error { + defer cancel() + + var ( + c chan types.APIEvent + err error + ) + + req := *apiOp + req.Namespace = partition.Namespace + if partition.All { + c, err = r.Store.Watch(&req, schema, w) + } else { + c, err = r.Store.WatchNames(&req, schema, w, partition.Names) + } + if err != nil { + return err + } + for i := range c { + response <- i + } + return nil + }) + } + + go func() { + defer close(response) + <-ctx.Done() + eg.Wait() + }() + + return response, nil +} diff --git a/pkg/server/store/proxy/watch_refresh.go b/pkg/server/store/proxy/watch_refresh.go new file mode 100644 index 00000000..47fe94da --- /dev/null +++ b/pkg/server/store/proxy/watch_refresh.go @@ -0,0 +1,45 @@ +package proxy + +import ( + "context" + "time" + + "github.com/rancher/steve/pkg/accesscontrol" + "github.com/rancher/steve/pkg/schemaserver/types" + "k8s.io/apiserver/pkg/endpoints/request" +) + +type WatchRefresh struct { + types.Store + asl accesscontrol.AccessSetLookup +} + +func (w *WatchRefresh) Watch(apiOp *types.APIRequest, schema *types.APISchema, wr types.WatchRequest) (chan types.APIEvent, error) { + user, ok := request.UserFrom(apiOp.Context()) + if !ok { + return w.Store.Watch(apiOp, schema, wr) + } + + as := w.asl.AccessFor(user) + ctx, cancel := context.WithCancel(apiOp.Context()) + apiOp = apiOp.WithContext(ctx) + + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + } + + newAs := w.asl.AccessFor(user) + if as.ID != newAs.ID { + // RBAC changed + cancel() + return + } + } + }() + + return w.Store.Watch(apiOp, schema, wr) +}