diff --git a/pkg/subscribe/handler.go b/pkg/subscribe/handler.go index d371957b..a8cab896 100644 --- a/pkg/subscribe/handler.go +++ b/pkg/subscribe/handler.go @@ -33,26 +33,7 @@ func Handler(apiContext *types.APIContext) error { return err } -func handler(apiContext *types.APIContext) error { - c, err := upgrader.Upgrade(apiContext.Response, apiContext.Request, nil) - if err != nil { - return err - } - defer c.Close() - - cancelCtx, cancel := context.WithCancel(apiContext.Request.Context()) - apiContext.Request = apiContext.Request.WithContext(cancelCtx) - - go func() { - for { - if _, _, err := c.NextReader(); err != nil { - cancel() - c.Close() - break - } - } - }() - +func getMatchingSchemas(apiContext *types.APIContext) []*types.Schema { apiVersions := apiContext.Request.URL.Query()["apiVersions"] resourceTypes := apiContext.Request.URL.Query()["resourceTypes"] @@ -69,11 +50,35 @@ func handler(apiContext *types.APIContext) error { } } + return schemas +} + +func handler(apiContext *types.APIContext) error { + schemas := getMatchingSchemas(apiContext) if len(schemas) == 0 { return httperror.NewAPIError(httperror.NotFound, "no resources types matched") } - readerGroup, ctx := errgroup.WithContext(apiContext.Request.Context()) + c, err := upgrader.Upgrade(apiContext.Response, apiContext.Request, nil) + if err != nil { + return err + } + defer c.Close() + + cancelCtx, cancel := context.WithCancel(apiContext.Request.Context()) + readerGroup, ctx := errgroup.WithContext(cancelCtx) + apiContext.Request = apiContext.Request.WithContext(ctx) + + go func() { + for { + if _, _, err := c.NextReader(); err != nil { + cancel() + c.Close() + break + } + } + }() + events := make(chan map[string]interface{}) for _, schema := range schemas { streamStore(ctx, readerGroup, apiContext, schema, events) diff --git a/store/transform/transform.go b/store/transform/transform.go index 48982d8e..80797a7c 100644 --- a/store/transform/transform.go +++ b/store/transform/transform.go @@ -1,6 +1,9 @@ package transform -import "github.com/rancher/norman/types" +import ( + "github.com/rancher/norman/types" + "github.com/rancher/norman/types/convert" +) type TransformerFunc func(apiContext *types.APIContext, data map[string]interface{}) (map[string]interface{}, error) @@ -36,18 +39,13 @@ func (t *Store) Watch(apiContext *types.APIContext, schema *types.Schema, opt *t return t.StreamTransformer(apiContext, c) } - result := make(chan map[string]interface{}) - go func() { - for item := range c { - item, err := t.Transformer(apiContext, item) - if err == nil && item != nil { - result <- item - } + return convert.Chan(c, func(data map[string]interface{}) map[string]interface{} { + item, err := t.Transformer(apiContext, data) + if err != nil { + return nil } - close(result) - }() - - return result, nil + return item + }), nil } func (t *Store) List(apiContext *types.APIContext, schema *types.Schema, opt *types.QueryOptions) ([]map[string]interface{}, error) { diff --git a/store/wrapper/wrapper.go b/store/wrapper/wrapper.go index e345cffc..b501fdc0 100644 --- a/store/wrapper/wrapper.go +++ b/store/wrapper/wrapper.go @@ -3,6 +3,7 @@ package wrapper import ( "github.com/rancher/norman/httperror" "github.com/rancher/norman/types" + "github.com/rancher/norman/types/convert" ) func Wrap(store types.Store) types.Store { @@ -42,20 +43,11 @@ func (s *StoreWrapper) Watch(apiContext *types.APIContext, schema *types.Schema, return nil, err } - result := make(chan map[string]interface{}) - go func() { - for item := range c { - item = apiContext.FilterObject(&types.QueryOptions{ - Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), - }, item) - if item != nil { - result <- item - } - } - close(result) - }() - - return result, nil + return convert.Chan(c, func(data map[string]interface{}) map[string]interface{} { + return apiContext.FilterObject(&types.QueryOptions{ + Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), + }, data) + }), nil } func (s *StoreWrapper) Create(apiContext *types.APIContext, schema *types.Schema, data map[string]interface{}) (map[string]interface{}, error) { diff --git a/types/convert/convert.go b/types/convert/convert.go index a2a6c49a..bb3327aa 100644 --- a/types/convert/convert.go +++ b/types/convert/convert.go @@ -10,6 +10,20 @@ import ( "unicode" ) +func Chan(c <-chan map[string]interface{}, f func(map[string]interface{}) map[string]interface{}) chan map[string]interface{} { + result := make(chan map[string]interface{}) + go func() { + for data := range c { + modified := f(data) + if modified != nil { + result <- modified + } + } + close(result) + }() + return result +} + func Singular(value interface{}) interface{} { if slice, ok := value.([]string); ok { if len(slice) == 0 {