diff --git a/api/handler/query.go b/api/handler/query.go index f3e8ebf6..2c544984 100644 --- a/api/handler/query.go +++ b/api/handler/query.go @@ -7,12 +7,12 @@ import ( "github.com/rancher/norman/types/convert" ) -func QueryFilter(opts *types.QueryOptions, data []map[string]interface{}) []map[string]interface{} { - return ApplyQueryOptions(opts, data) +func QueryFilter(opts *types.QueryOptions, schema *types.Schema, data []map[string]interface{}) []map[string]interface{} { + return ApplyQueryOptions(opts, schema, data) } -func ApplyQueryOptions(options *types.QueryOptions, data []map[string]interface{}) []map[string]interface{} { - data = ApplyQueryConditions(options.Conditions, data) +func ApplyQueryOptions(options *types.QueryOptions, schema *types.Schema, data []map[string]interface{}) []map[string]interface{} { + data = ApplyQueryConditions(options.Conditions, schema, data) data = ApplySort(options.Sort, data) return ApplyPagination(options.Pagination, data) } @@ -35,13 +35,13 @@ func ApplySort(sortOpts types.Sort, data []map[string]interface{}) []map[string] return data } -func ApplyQueryConditions(conditions []*types.QueryCondition, data []map[string]interface{}) []map[string]interface{} { +func ApplyQueryConditions(conditions []*types.QueryCondition, schema *types.Schema, data []map[string]interface{}) []map[string]interface{} { var result []map[string]interface{} outer: for _, item := range data { for _, condition := range conditions { - if !condition.Valid(item) { + if !condition.Valid(schema, item) { continue outer } } diff --git a/store/wrapper/wrapper.go b/store/wrapper/wrapper.go index fdbe0f57..24a9f7b2 100644 --- a/store/wrapper/wrapper.go +++ b/store/wrapper/wrapper.go @@ -28,7 +28,7 @@ func (s *StoreWrapper) ByID(apiContext *types.APIContext, schema *types.Schema, return apiContext.FilterObject(&types.QueryOptions{ Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), - }, data), nil + }, schema, data), nil } func (s *StoreWrapper) List(apiContext *types.APIContext, schema *types.Schema, opts *types.QueryOptions) ([]map[string]interface{}, error) { @@ -38,7 +38,7 @@ func (s *StoreWrapper) List(apiContext *types.APIContext, schema *types.Schema, return nil, err } - return apiContext.FilterList(opts, data), nil + return apiContext.FilterList(opts, schema, data), nil } func (s *StoreWrapper) Watch(apiContext *types.APIContext, schema *types.Schema, opt *types.QueryOptions) (chan map[string]interface{}, error) { @@ -50,7 +50,7 @@ func (s *StoreWrapper) Watch(apiContext *types.APIContext, schema *types.Schema, return convert.Chan(c, func(data map[string]interface{}) map[string]interface{} { return apiContext.FilterObject(&types.QueryOptions{ Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), - }, data) + }, schema, data) }), nil } @@ -83,7 +83,7 @@ func (s *StoreWrapper) Update(apiContext *types.APIContext, schema *types.Schema return apiContext.FilterObject(&types.QueryOptions{ Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), - }, data), nil + }, schema, data), nil } func (s *StoreWrapper) Delete(apiContext *types.APIContext, schema *types.Schema, id string) (map[string]interface{}, error) { @@ -107,7 +107,7 @@ func validateGet(apiContext *types.APIContext, schema *types.Schema, id string) if apiContext.Filter(&types.QueryOptions{ Conditions: apiContext.SubContextAttributeProvider.Query(apiContext, schema), - }, existing) == nil { + }, schema, existing) == nil { return httperror.NewAPIError(httperror.NotFound, "failed to find "+id) } diff --git a/types/condition.go b/types/condition.go index 64dbd433..9be6210b 100644 --- a/types/condition.go +++ b/types/condition.go @@ -39,35 +39,44 @@ type QueryCondition struct { left, right *QueryCondition } -func (q *QueryCondition) Valid(data map[string]interface{}) bool { +func (q *QueryCondition) Valid(schema *Schema, data map[string]interface{}) bool { switch q.conditionType { case CondAnd: if q.left == nil || q.right == nil { return false } - return q.left.Valid(data) && q.right.Valid(data) + return q.left.Valid(schema, data) && q.right.Valid(schema, data) case CondOr: if q.left == nil || q.right == nil { return false } - return q.left.Valid(data) || q.right.Valid(data) + return q.left.Valid(schema, data) || q.right.Valid(schema, data) case CondEQ: - return q.Value == convert.ToString(data[q.Field]) + return q.Value == convert.ToString(valueOrDefault(schema, data, q)) case CondNE: - return q.Value != convert.ToString(data[q.Field]) + return q.Value != convert.ToString(valueOrDefault(schema, data, q)) case CondIn: - return q.Values[convert.ToString(data[q.Field])] + return q.Values[convert.ToString(valueOrDefault(schema, data, q))] case CondNotIn: - return !q.Values[convert.ToString(data[q.Field])] + return !q.Values[convert.ToString(valueOrDefault(schema, data, q))] case CondNotNull: - return convert.ToString(data[q.Field]) != "" + return convert.ToString(valueOrDefault(schema, data, q)) != "" case CondNull: - return convert.ToString(data[q.Field]) == "" + return convert.ToString(valueOrDefault(schema, data, q)) == "" } return false } +func valueOrDefault(schema *Schema, data map[string]interface{}, q *QueryCondition) interface{} { + value := data[q.Field] + if value == nil { + value = schema.ResourceFields[q.Field].Default + } + + return value +} + func (q *QueryCondition) ToCondition() Condition { cond := Condition{ Modifier: q.conditionType.Name, diff --git a/types/server_types.go b/types/server_types.go index 74f27d7b..3b11ead7 100644 --- a/types/server_types.go +++ b/types/server_types.go @@ -49,7 +49,7 @@ type ActionHandler func(actionName string, action *Action, request *APIContext) type RequestHandler func(request *APIContext, next RequestHandler) error -type QueryFilter func(opts *QueryOptions, data []map[string]interface{}) []map[string]interface{} +type QueryFilter func(opts *QueryOptions, schema *Schema, data []map[string]interface{}) []map[string]interface{} type Validator func(request *APIContext, schema *Schema, data map[string]interface{}) error @@ -127,25 +127,25 @@ func (r *APIContext) WriteResponse(code int, obj interface{}) { r.ResponseWriter.Write(r, code, obj) } -func (r *APIContext) FilterList(opts *QueryOptions, obj []map[string]interface{}) []map[string]interface{} { - return r.QueryFilter(opts, obj) +func (r *APIContext) FilterList(opts *QueryOptions, schema *Schema, obj []map[string]interface{}) []map[string]interface{} { + return r.QueryFilter(opts, schema, obj) } -func (r *APIContext) FilterObject(opts *QueryOptions, obj map[string]interface{}) map[string]interface{} { +func (r *APIContext) FilterObject(opts *QueryOptions, schema *Schema, obj map[string]interface{}) map[string]interface{} { opts.Pagination = nil - result := r.QueryFilter(opts, []map[string]interface{}{obj}) + result := r.QueryFilter(opts, schema, []map[string]interface{}{obj}) if len(result) == 0 { return nil } return result[0] } -func (r *APIContext) Filter(opts *QueryOptions, obj interface{}) interface{} { +func (r *APIContext) Filter(opts *QueryOptions, schema *Schema, obj interface{}) interface{} { switch v := obj.(type) { case []map[string]interface{}: - return r.FilterList(opts, v) + return r.FilterList(opts, schema, v) case map[string]interface{}: - return r.FilterObject(opts, v) + return r.FilterObject(opts, schema, v) } return nil