1
0
mirror of https://github.com/rancher/steve.git synced 2025-05-11 01:16:42 +00:00

[v0.3] Migrate SQL cache to Steve ()

* Migrate SQLcache to Steve

* Fix imports

* go mod tidy

* Fix lint errors

* Remove lasso SQL cache mentions

* Fix more CI lint errors

* fix goimports

Signed-off-by: Silvio Moioli <silvio@moioli.net>

* Fix more linting errors

* More lint fix

* Add envtest support

---------

Signed-off-by: Silvio Moioli <silvio@moioli.net>
Co-authored-by: Silvio Moioli <silvio@moioli.net>
This commit is contained in:
Tom Lebreux 2025-02-04 12:41:59 -05:00 committed by GitHub
parent d50101289f
commit 9741028761
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 9999 additions and 32 deletions

View File

@ -25,6 +25,8 @@ jobs:
uses: golangci/golangci-lint-action@a4f60bb28d35aeee14e6880718e0c85ff1882e64 # v6.0.1
with:
version: v1.63.4
- name: Install env-test
run: go install sigs.k8s.io/controller-runtime/tools/setup-envtest@latest
- name: Build
run: make build-bin
- name: Test

6
go.mod
View File

@ -82,6 +82,9 @@ require (
k8s.io/klog v1.0.0
k8s.io/kube-aggregator v0.30.1
k8s.io/kube-openapi v0.0.0-20240411171206-dc4e619f62f3
k8s.io/utils v0.0.0-20240711033017-18e509b52bc8
modernc.org/sqlite v1.29.10
sigs.k8s.io/controller-runtime v0.19.0
)
require (
@ -95,6 +98,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch v5.9.0+incompatible // indirect
github.com/evanphx/json-patch/v5 v5.9.0 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
@ -150,12 +154,10 @@ require (
gopkg.in/yaml.v2 v2.4.0 // indirect
k8s.io/component-base v0.30.1 // indirect
k8s.io/klog/v2 v2.130.1 // indirect
k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect
modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect
modernc.org/libc v1.49.3 // indirect
modernc.org/mathutil v1.6.0 // indirect
modernc.org/memory v1.8.0 // indirect
modernc.org/sqlite v1.29.10 // indirect
modernc.org/strutil v1.2.0 // indirect
modernc.org/token v1.1.0 // indirect
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0 // indirect

17
go.sum
View File

@ -683,12 +683,16 @@ github.com/envoyproxy/protoc-gen-validate v0.10.0/go.mod h1:DRjgyB0I43LtJapqN6Ni
github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls=
github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg=
github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
@ -714,6 +718,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ=
github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg=
github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
github.com/go-openapi/jsonreference v0.20.1/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
@ -740,6 +746,7 @@ github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -1096,6 +1103,10 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
@ -1133,6 +1144,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w=
golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@ -1546,6 +1559,8 @@ golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw=
gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
@ -1965,6 +1980,8 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0 h1:/U5vjBbQn3RCh
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0/go.mod h1:z7+wmGM2dfIiLRfrC6jb5kV2Mq/sK1ZP303cxzkV5Y4=
sigs.k8s.io/cli-utils v0.37.2 h1:GOfKw5RV2HDQZDJlru5KkfLO1tbxqMoyn1IYUxqBpNg=
sigs.k8s.io/cli-utils v0.37.2/go.mod h1:V+IZZr4UoGj7gMJXklWBg6t5xbdThFBcpj4MrZuCYco=
sigs.k8s.io/controller-runtime v0.18.5 h1:nTHio/W+Q4aBlQMgbnC5hZb4IjIidyrizMai9P6n4Rk=
sigs.k8s.io/controller-runtime v0.18.5/go.mod h1:TVoGrfdpbA9VRFaRnKgk9P5/atA0pMwq+f+msb9M8Sg=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo=
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0=
sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E=

View File

@ -67,7 +67,7 @@ type Options struct {
AggregationSecretName string
ClusterRegistry string
ServerVersion string
// SQLCache enables the SQLite-based lasso caching mechanism
// SQLCache enables the SQLite-based caching mechanism
SQLCache bool
}

157
pkg/sqlcache/Readme.md Normal file
View File

@ -0,0 +1,157 @@
# SQL Cache
## Sections
- [ListOptions Informer](#listoptions-informer)
- [List Options](#list-options)
- [ListOption Indexer](#listoptions-indexer)
- [SQL Store](#sql-store)
- [Partitions](#partitions)
- [How to Use](#how-to-use)
- [Technical Information](#technical-information)
- [SQL Tables](#sql-tables)
- [SQLite Driver](#sqlite-driver)
- [Connection Pooling](#connection-pooling)
- [Encryption Defaults](#encryption-defaults)
- [Indexed Fields](#indexed-fields)
- [ListOptions Behavior](#listoptions-behavior)
- [Troubleshooting Sqlite](#troubleshooting-sqlite)
## ListOptions Informer
The main usable feature from the SQL cache is the ListOptions Informer. The ListOptionsInformer provides listing functionality,
like any other informer, but with a wider array of options. The options are configured by informer.ListOptions.
### List Options
ListOptions includes the following:
* Match filters for indexed fields. Filters are for specifying the value a given field in an object should be in order to
be included in the list. Filters can be set to equals or not equals. Filters can be set to look for partial matches or
exact (strict) matches. Filters can be OR'd and AND'd with one another. Filters only work on fields that have been indexed.
* Primary field and secondary field sorting order. Can choose up to two fields to sort on. Sort order can be ascending
or descending. Default sorting is to sort on metadata.namespace in ascending first and then sort on metadata.name.
* Page size to specify how many items to include in a response.
* Page number to specify offset. For example, a page size of 50 and a page number of 2, will return items starting at
index 50. Index will be dependent on sort. Page numbers start at 1.
### ListOptions Factory
The ListOptions Factory helps manage multiple ListOption Informers. A user can call Factory.InformerFor(), to create new
ListOptions informers if they do not exist and retrieve existing ones.
### ListOptions Indexer
Like all other informers, the ListOptions informer uses an indexer to cache objects of the informer's type. A few features
set the ListOptions Indexer apart from others indexers:
* an on-disk store instead of an in-memory store.
* accepts list options backed by SQL queries for extended search/filter/sorting capability.
* AES GCM encryption using key hierarchy.
### SQL Store
The SQL store is the main interface for interacting with the database. This store backs the indexer, and provides all
functionality required by the cache.Store interface.
### Partitions
Partitions are constraints for ListOptionsInform ListByOptions() method that are separate from ListOptions. Partitions
are strict conditions that dictate which namespaces or names can be searched from. These overrule ListOptions and are
intended to be used as a way of enforcing RBAC.
## How to Use
```go
package main
import(
"k8s.io/client-go/dynamic"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/informer/factory"
)
func main() {
cacheFactory, err := factory.NewCacheFactory()
if err != nil {
panic(err)
}
// config should be some rest config created from kubeconfig
// there are other ways to create a config and any client that conforms to k8s.io/client-go/dynamic.ResourceInterface
// will work.
client, err := dynamic.NewForConfig(config)
if err != nil {
panic(err)
}
fields := [][]string{{"metadata", "name"}, {"metadata", "namespace"}}
opts := &informer.ListOptions{}
// gvk should be of type k8s.io/apimachinery/pkg/runtime/schema.GroupVersionKind
c, err := cacheFactory.CacheFor(fields, client, gvk)
if err != nil {
panic(err)
}
// continueToken will just be an offset that can be used in Resume on a subsequent request to continue
// to next page
list, continueToken, err := c.ListByOptions(apiOp.Context(), opts, partitions, namespace)
if err != nil {
panic(err)
}
}
```
## Technical Information
### SQL Tables
There are three tables that are created for the ListOption informer:
* object table - this contains objects, including all their fields, as blobs. These blobs may be encrypted.
* fields table - this contains specific fields of value for objects. These are specified on informer create and are fields
that it is desired to filter or order on.
* indices table - the indices table stores indexes created and objects' values for each index. This backs the generic indexer
that contains the functionality needed to conform to cache.Indexer.
### SQLite Driver
There are multiple SQLite drivers that this package could have used. One of the most, if not the most, popular SQLite golang
drivers is [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). This driver is not being used because it requires enabling
the cgo option when compiling and at the moment steve's main consumer, rancher, does not compile with cgo. We did not want
the SQL informer to be the sole driver in switching to using cgo. Instead, modernc's driver which is in pure golang. Side-by-side
comparisons can be found indicating the cgo version is, as expected, more performant. If in the future it is deemed worthwhile
then the driver can be easily switched by replacing the empty import in `pkg/cache/sql/store` from `_ "modernc.org/sqlite"` to `_ "github.com/mattn/go-sqlite3"`.
### Connection Pooling
While working with the `database/sql` package for go, it is important to understand how sql.Open() and other methods manage
connections. Open starts a connection pool; that is to say after calling open once, there may be anywhere from zero to many
connections attached to a sql.Connection. `database/sql` manages this connection pool under the hood. In most cases, an
application only need one sql.Connection, although sometimes application use two: one for writes, the other for reads. To
read more about the `sql` package's connection pooling read [Managing connections](https://go.dev/doc/database/manage-connections).
The use of connection pooling and the fact that steve potentially has many go routines accessing the same connection pool,
means we have to be careful with writes. Exclusively using sql transaction to write helps ensure safety. To read more about
sql transactions read SQLite's [Transaction docs](https://www.sqlite.org/lang_transaction.html).
### Encryption Defaults
By default only specified types are encrypted. These types are hard-coded and defined by defaultEncryptedResourceTypes
in `pkg/cache/sql/informer/factory/informer_factory.go`. To enabled encryption for all types, set the ENV variable
`CATTLE_ENCRYPT_CACHE_ALL` to "true".
The key size used is 256 bits. Data-encryption-keys are stored in the object table and are rotated every 150,000 writes.
### Indexed Fields
Filtering and sorting only work on indexed fields. These fields are defined when using `CacheFor`. Objects will
have the following indexes by default:
* Fields in informer.defaultIndexedFields
* Fields passed to InformerFor()
### ListOptions Behavior
Defaults:
* Sort.PrimaryField: `metadata.namespace`
* Sort.SecondaryField: `metadata.name`
* Sort.PrimaryOrder: `ASC` (ascending)
* Sort.SecondaryOrder: `ASC` (ascending)
* All filters have partial matching set to false by default
There are some uncommon ways someone could use ListOptions where it would be difficult to predict what the result would be.
Below is a non-exhaustive list of some of these cases and what the behavior is:
* Setting Pagination.Page but not Pagination.PageSize will cause Page to be ignored
* Setting Sort.SecondaryField only will sort as though it was Sort.PrimaryField. Sort.SecondaryOrder will still be applied
and Sort.PrimaryOrder will be ignored
### Writing Secure Queries
Values should be supplied to SQL queries using placeholders, read [Avoiding SQL Injection Risk](https://go.dev/doc/database/sql-injection). Any other portions
of a query that may be user supplied, such as columns, should be carefully validated against a fixed set of acceptable values.
### Troubleshooting SQLite
A useful tool for troubleshooting the database files is the sqlite command line tool. Another useful tool is the goland
sqlite plugin. Both of these tools can be used with the database files.

374
pkg/sqlcache/db/client.go Normal file
View File

@ -0,0 +1,374 @@
/*
Package db offers client struct and functions to interact with database connection. It provides encrypting, decrypting,
and a way to reset the database.
*/
package db
import (
"bytes"
"context"
"database/sql"
"encoding/gob"
"fmt"
"io/fs"
"os"
"reflect"
"sync"
"github.com/pkg/errors"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
// needed for drivers
_ "modernc.org/sqlite"
)
const (
// InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve
InformerObjectCacheDBPath = "informer_object_cache.db"
informerObjectCachePerms fs.FileMode = 0o600
)
// Client is a database client that provides encrypting, decrypting, and database resetting.
type Client struct {
conn Connection
connLock sync.RWMutex
encryptor Encryptor
decryptor Decryptor
}
// Connection represents a connection pool.
type Connection interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Exec(query string, args ...any) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Close() error
}
// Closable Closes an underlying connection and returns an error on failure.
type Closable interface {
Close() error
}
// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them.
type Rows interface {
Next() bool
Err() error
Close() error
Scan(dest ...any) error
}
// QueryError encapsulates an error while executing a query
type QueryError struct {
QueryString string
Err error
}
// Error returns a string representation of this QueryError
func (e *QueryError) Error() string {
return "while executing query: " + e.QueryString + " got error: " + e.Err.Error()
}
// Unwrap returns the underlying error
func (e *QueryError) Unwrap() error {
return e.Err
}
// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed.
type TXClient interface {
StmtExec(stmt transaction.Stmt, args ...any) error
Exec(stmt string, args ...any) error
Commit() error
Stmt(stmt *sql.Stmt) transaction.Stmt
Cancel() error
}
// Encryptor encrypts data with a key which is rotated to avoid wear-out.
type Encryptor interface {
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
Encrypt([]byte) ([]byte, []byte, uint32, error)
}
// Decryptor decrypts data previously encrypted by Encryptor.
type Decryptor interface {
// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error.
Decrypt([]byte, []byte, uint32) ([]byte, error)
}
// NewClient returns a Client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) {
client := &Client{
encryptor: encryptor,
decryptor: decryptor,
}
if c != nil {
client.conn = c
return client, nil
}
err := client.NewConnection()
if err != nil {
return nil, err
}
return client, nil
}
// Prepare prepares the given string into a sql statement on the client's connection.
func (c *Client) Prepare(stmt string) *sql.Stmt {
c.connLock.RLock()
defer c.connLock.RUnlock()
prepared, err := c.conn.Prepare(stmt)
if err != nil {
panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err))
}
return prepared
}
// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried
// given a sqlite busy error.
func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
return stmt.QueryContext(ctx, params...)
}
// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant
// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection.
func (c *Client) CloseStmt(closable Closable) error {
return closable.Close()
}
// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type,
// and returns a slice of those objects.
func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
var result []any
for rows.Next() {
data, err := c.decryptScan(rows, shouldDecrypt)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
singleResult, err := fromBytes(data, typ)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
result = append(result, singleResult.Elem().Interface())
}
err := rows.Err()
if err != nil {
return nil, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return nil, err
}
return result, nil
}
// ReadStrings scans the given rows into strings, and then returns the strings as a slice.
func (c *Client) ReadStrings(rows Rows) ([]string, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
var result []string
for rows.Next() {
var key string
err := rows.Scan(&key)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
result = append(result, key)
}
err := rows.Err()
if err != nil {
return nil, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return nil, err
}
return result, nil
}
// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries)
func (c *Client) ReadInt(rows Rows) (int, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
if !rows.Next() {
return 0, closeRowsOnError(rows, sql.ErrNoRows)
}
var result int
err := rows.Scan(&result)
if err != nil {
return 0, closeRowsOnError(rows, err)
}
err = rows.Err()
if err != nil {
return 0, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return 0, err
}
return result, nil
}
// BeginTx attempts to begin a transaction.
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (eg. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
if err != nil {
return nil, err
}
return transaction.NewClient(sqlTx), nil
}
func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
var data, dataNonce sql.RawBytes
var kid uint32
err := rows.Scan(&data, &dataNonce, &kid)
if err != nil {
return nil, err
}
if c.decryptor != nil && shouldDecrypt {
decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid)
if err != nil {
return nil, err
}
return decryptedData, nil
}
return data, nil
}
// Upsert used to be called upsertEncrypted in store package before move
func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
objBytes := toBytes(obj)
var dataNonce []byte
var err error
var kid uint32
if c.encryptor != nil && shouldEncrypt {
objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes)
if err != nil {
return err
}
}
return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
}
// toBytes encodes an object to a byte slice
func toBytes(obj any) []byte {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(obj)
if err != nil {
panic(fmt.Errorf("error while gobbing object: %w", err))
}
bb := buf.Bytes()
return bb
}
// fromBytes decodes an object from a byte slice
func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) {
dec := gob.NewDecoder(bytes.NewReader(buf))
singleResult := reflect.New(typ)
err := dec.DecodeValue(singleResult)
return singleResult, err
}
// closeRowsOnError closes the sql.Rows object and wraps errors if needed
func closeRowsOnError(rows Rows, err error) error {
ce := rows.Close()
if ce != nil {
return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce)
}
return err
}
// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently
// creates new files.
func (c *Client) NewConnection() error {
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
err := c.conn.Close()
if err != nil {
return err
}
}
err := os.RemoveAll(InformerObjectCacheDBPath)
if err != nil {
return err
}
// Set the permissions in advance, because we can't control them if
// the file is created by a sql.Open call instead.
if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil {
return nil
}
sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+
// open SQLite file in read-write mode, creating it if it does not exist
"mode=rwc&"+
// use the WAL journal mode for consistency and efficiency
"_pragma=journal_mode=wal&"+
// do not even attempt to attain durability. Database is thrown away at pod restart
"_pragma=synchronous=off&"+
// do check foreign keys and honor ON DELETE CASCADE
"_pragma=foreign_keys=on&"+
// if two transactions want to write at the same time, allow 2 minutes for the first to complete
// before baling out
"_pragma=busy_timeout=120000&"+
// default to IMMEDIATE mode for transactions. Setting this parameter is the only current way
// to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation
// of BeginTx
"_txlock=immediate")
if err != nil {
return err
}
c.conn = sqlDB
return nil
}
// This acts like "touch" for both existing files and non-existing files.
// permissions.
//
// It's created with the correct perms, and if the file already exists, it will
// be chmodded to the correct perms.
func touchFile(filename string, perms fs.FileMode) error {
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms)
if err != nil {
return err
}
if err := f.Close(); err != nil {
return err
}
return os.Chmod(filename, perms)
}

View File

@ -0,0 +1,667 @@
package db
import (
"context"
"database/sql"
"fmt"
"io/fs"
"math"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
// Mocks for this test are generated with the following command.
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
type testStoreObject struct {
Id string
Val string
}
func TestNewClient(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
expectedClient := &Client{
conn: c,
encryptor: e,
decryptor: d,
}
client, err := NewClient(c, e, d)
assert.Nil(t, err)
assert.Equal(t, expectedClient, client)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryForRows(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
client := SetupClient(t, c, nil, nil)
s := NewMockStmt(gomock.NewController(t))
ctx := context.TODO()
r := &sql.Rows{}
s.EXPECT().QueryContext(ctx).Return(r, nil)
rows, err := client.QueryForRows(ctx, s)
assert.Nil(t, err)
assert.Equal(t, r, rows)
},
})
tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
client := SetupClient(t, c, nil, nil)
s := NewMockStmt(gomock.NewController(t))
ctx := context.TODO()
s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error"))
_, err := client.QueryForRows(ctx, s)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryObjects(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
var keyId uint32 = math.MaxUint32
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.Nil(t, err)
assert.Equal(t, 1, len(items))
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(
testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.Nil(t, err)
assert.Equal(t, 0, len(items))
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryStrings(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadStrings(r)
assert.Nil(t, err)
assert.Equal(t, 1, len(items))
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadStrings(r)
assert.Nil(t, err)
assert.Equal(t, 0, len(items))
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestReadInt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testResult := 42
tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
p := a[0].(*int)
*p = testResult
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
result, err := client.ReadInt(r)
assert.Nil(t, err)
assert.Equal(t, 42, result)
},
})
tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
a[0] = testResult
})
r.EXPECT().Err().Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
a[0] = testResult
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.ErrorIs(t, err, sql.ErrNoRows)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestBegin(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), false)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), true)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.BeginTx(context.Background(), false)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestUpsert(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
var keyID uint32 = 5
// Tests with shouldEncryptSet to true
tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.Nil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
testObjBytes := toBytes(testObject)
e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) {
c := SetupMockConnection(t)
d := SetupMockDecryptor(t)
e := SetupMockEncryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
var testByteValue []byte
testObjBytes := toBytes(testObject)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, false)
assert.Nil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestPrepare(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
sqlStmt := &sql.Stmt{}
c.EXPECT().Prepare("something").Return(sqlStmt, nil)
stmt := client.Prepare("something")
assert.Equal(t, sqlStmt, stmt)
},
})
tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error"))
assert.Panics(t, func() { client.Prepare("something") })
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestNewConnection(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
c.EXPECT().Close().Return(nil)
err := client.NewConnection()
assert.Nil(t, err)
// Create a transaction to ensure that the file is written to disk.
txC, err := client.BeginTx(context.Background(), false)
assert.NoError(t, err)
assert.NoError(t, txC.Commit())
assert.FileExists(t, InformerObjectCacheDBPath)
assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600)
err = os.Remove(InformerObjectCacheDBPath)
if err != nil {
assert.Fail(t, "could not remove object cache path after test")
}
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestCommit(t *testing.T) {
}
func TestRollback(t *testing.T) {
}
func SetupMockConnection(t *testing.T) *MockConnection {
mockC := NewMockConnection(gomock.NewController(t))
return mockC
}
func SetupMockEncryptor(t *testing.T) *MockEncryptor {
mockE := NewMockEncryptor(gomock.NewController(t))
return mockE
}
func SetupMockDecryptor(t *testing.T) *MockDecryptor {
MockD := NewMockDecryptor(gomock.NewController(t))
return MockD
}
func SetupMockRows(t *testing.T) *MockRows {
MockR := NewMockRows(gomock.NewController(t))
return MockR
}
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client {
c, _ := NewClient(connection, encryptor, decryptor)
return c
}
func TestTouchFile(t *testing.T) {
t.Run("File doesn't exist before", func(t *testing.T) {
filename := filepath.Join(t.TempDir(), "test1.txt")
assert.NoError(t, touchFile(filename, 0600))
assertFileHasPermissions(t, filename, 0600)
})
t.Run("File exists with different permissions", func(t *testing.T) {
filename := filepath.Join(t.TempDir(), "test2.txt")
assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644))
assert.NoError(t, touchFile(filename, 0600))
assertFileHasPermissions(t, filename, 0600)
})
}
func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool {
t.Helper()
info, err := os.Lstat(fname)
if err != nil {
if os.IsNotExist(err) {
return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname))
}
return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err))
}
// Stringifying the perms makes it easier to read than a Hex comparison.
assert.Equal(t, wantPerms.String(), info.Mode().Perm().String())
return true
}

View File

@ -0,0 +1,370 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//
// Package db is a generated GoMock package.
package db
import (
context "context"
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockRows is a mock of Rows interface.
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
}
// MockRowsMockRecorder is the mock recorder for MockRows.
type MockRowsMockRecorder struct {
mock *MockRows
}
// NewMockRows creates a new mock instance.
func NewMockRows(ctrl *gomock.Controller) *MockRows {
mock := &MockRows{ctrl: ctrl}
mock.recorder = &MockRowsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRows) EXPECT() *MockRowsMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRows) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRowsMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close))
}
// Err mocks base method.
func (m *MockRows) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err.
func (mr *MockRowsMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err))
}
// Next mocks base method.
func (m *MockRows) Next() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next")
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next.
func (mr *MockRowsMockRecorder) Next() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next))
}
// Scan mocks base method.
func (m *MockRows) Scan(arg0 ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Scan indicates an expected call of Scan.
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
}
// MockConnection is a mock of Connection interface.
type MockConnection struct {
ctrl *gomock.Controller
recorder *MockConnectionMockRecorder
}
// MockConnectionMockRecorder is the mock recorder for MockConnection.
type MockConnectionMockRecorder struct {
mock *MockConnection
}
// NewMockConnection creates a new mock instance.
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
mock := &MockConnection{ctrl: ctrl}
mock.recorder = &MockConnectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
return m.recorder
}
// BeginTx mocks base method.
func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(*sql.Tx)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1)
}
// Close mocks base method.
func (m *MockConnection) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockConnectionMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close))
}
// Exec mocks base method.
func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
}
// Prepare mocks base method.
func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Prepare indicates an expected call of Prepare.
func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0)
}
// MockEncryptor is a mock of Encryptor interface.
type MockEncryptor struct {
ctrl *gomock.Controller
recorder *MockEncryptorMockRecorder
}
// MockEncryptorMockRecorder is the mock recorder for MockEncryptor.
type MockEncryptorMockRecorder struct {
mock *MockEncryptor
}
// NewMockEncryptor creates a new mock instance.
func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor {
mock := &MockEncryptor{ctrl: ctrl}
mock.recorder = &MockEncryptorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder {
return m.recorder
}
// Encrypt mocks base method.
func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encrypt", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].([]byte)
ret2, _ := ret[2].(uint32)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// Encrypt indicates an expected call of Encrypt.
func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0)
}
// MockDecryptor is a mock of Decryptor interface.
type MockDecryptor struct {
ctrl *gomock.Controller
recorder *MockDecryptorMockRecorder
}
// MockDecryptorMockRecorder is the mock recorder for MockDecryptor.
type MockDecryptorMockRecorder struct {
mock *MockDecryptor
}
// NewMockDecryptor creates a new mock instance.
func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor {
mock := &MockDecryptor{ctrl: ctrl}
mock.recorder = &MockDecryptorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder {
return m.recorder
}
// Decrypt mocks base method.
func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Decrypt indicates an expected call of Decrypt.
func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2)
}
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}

View File

@ -0,0 +1,90 @@
/*
Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures.
The use of this package simplifies testing for callers by making the underlying transaction mock-able.
*/
package transaction
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
// Client provides a way to interact with the underlying sql transaction.
type Client struct {
sqlTx SQLTx
}
// SQLTx represents a sql transaction
type SQLTx interface {
Exec(query string, args ...any) (sql.Result, error)
Stmt(stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
}
// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type
// because we are able to mock its outputs and do not need an actual connection.
type Stmt interface {
Exec(args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
// NewClient returns a Client with the given transaction assigned.
func NewClient(tx SQLTx) *Client {
return &Client{sqlTx: tx}
}
// Commit commits the transaction and then unlocks the database.
func (c *Client) Commit() error {
return c.sqlTx.Commit()
}
// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec()
// returns an error.
func (c *Client) Exec(stmt string, args ...any) error {
_, err := c.sqlTx.Exec(stmt, args...)
if err != nil {
return c.rollback(c.sqlTx, err)
}
return nil
}
// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned
// here to aid in testing callers by providing a way to configure the statement's behavior.
func (c *Client) Stmt(stmt *sql.Stmt) Stmt {
s := c.sqlTx.Stmt(stmt)
return s
}
// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The
// transaction is rolled back if Stmt.Exec() returns an error.
func (c *Client) StmtExec(stmt Stmt, args ...any) error {
_, err := stmt.Exec(args...)
if err != nil {
return c.rollback(c.sqlTx, err)
}
return nil
}
// rollback handles rollbacks and wraps errors if needed
func (c *Client) rollback(tx SQLTx, err error) error {
rerr := tx.Rollback()
if rerr != nil {
return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr)
}
return errors.Wrapf(err, "Encountered error, successfully rolled back")
}
// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned
// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too
// late.
func (c *Client) Cancel() error {
rerr := c.sqlTx.Rollback()
if rerr != sql.ErrTxDone {
return rerr
}
return nil
}

View File

@ -0,0 +1,184 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//
// Package transaction is a generated GoMock package.
package transaction
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}

View File

@ -0,0 +1,182 @@
package transaction
import (
"database/sql"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
func TestNewClient(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
c := NewClient(tx)
assert.Equal(t, tx, c.sqlTx)
}
func TestCommit(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := &sql.Stmt{}
var returnedTXStmt *sql.Stmt
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Stmt(stmt).Return(returnedTXStmt)
c := &Client{
sqlTx: tx,
}
returnedStmt := c.Stmt(stmt)
// whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it
// won't be equal to an empty struct
assert.Equal(t, returnedTXStmt, returnedStmt)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmtExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error2"))
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

View File

@ -0,0 +1,184 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//
// Package db is a generated GoMock package.
package db
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}

View File

@ -0,0 +1,8 @@
package db
import "strings"
// Sanitize returns a string that can be used in SQL as a name
func Sanitize(s string) string {
return strings.ReplaceAll(s, "\"", "")
}

View File

@ -0,0 +1,168 @@
/*
Package encryption provides encryption and decryption functions, while
abstracting away key management concerns.
Uses AES-GCM encryption, with key rotation, keeping keys in memory.
*/
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"sync"
"github.com/pkg/errors"
)
var (
ErrKeyNotFound = errors.New("data key not found")
// maxWriteCount holds the maximum amount of times the active key can be
// used, prior to it being rotated. 2^32 is the currently recommended key
// wear-out params by NIST for AES-GCM using random nonces.
maxWriteCount int64 = 1 << 32
)
const (
keySize = 32 // 32 for AES-256
)
// Manager uses AES-GCM encryption and keeps in memory the data encryption
// keys. The active encryption key is automatically rotated once it has been
// used over a certain amount of times - defined by maxWriteCount.
type Manager struct {
dataKeys [][]byte
activeKeyCounter int64
// lock works as the mutual exclusion lock for dataKeys.
lock sync.RWMutex
// counterLock works as the mutual exclusion lock for activeKeyCounter.
counterLock sync.Mutex
}
// NewManager returns Manager, which satisfies db.Encryptor and db.Decryptor
func NewManager() (*Manager, error) {
m := &Manager{
dataKeys: [][]byte{},
}
m.newDataEncryptionKey()
return m, nil
}
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
func (m *Manager) Encrypt(data []byte) ([]byte, []byte, uint32, error) {
dek, keyID, err := m.fetchActiveDataKey()
if err != nil {
return nil, nil, 0, err
}
aead, err := createGCMCypher(dek)
if err != nil {
return nil, nil, 0, err
}
edata, nonce, err := encrypt(aead, data)
if err != nil {
return nil, nil, 0, err
}
return edata, nonce, keyID, nil
}
// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error.
func (m *Manager) Decrypt(edata, nonce []byte, keyID uint32) ([]byte, error) {
dek, err := m.key(keyID)
if err != nil {
return nil, err
}
aead, err := createGCMCypher(dek)
if err != nil {
return nil, errors.Wrap(err, "failed to create GCMCypher from DEK")
}
data, err := aead.Open(nil, nonce, edata, nil)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("failed to decrypt data using keyid %d", keyID))
}
return data, nil
}
func encrypt(aead cipher.AEAD, data []byte) ([]byte, []byte, error) {
if aead == nil {
return nil, nil, fmt.Errorf("aead is nil, cannot encrypt data")
}
nonce := make([]byte, aead.NonceSize())
_, err := rand.Read(nonce)
if err != nil {
return nil, nil, err
}
sealed := aead.Seal(nil, nonce, data, nil)
return sealed, nonce, nil
}
func createGCMCypher(key []byte) (cipher.AEAD, error) {
b, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(b)
if err != nil {
return nil, err
}
return aead, nil
}
// fetchActiveDataKey returns the current data key and its key ID.
// Each call results in activeKeyCounter being incremented by 1. When the
// the activeKeyCounter exceeds maxWriteCount, the active data key is
// rotated - before being returned.
func (m *Manager) fetchActiveDataKey() ([]byte, uint32, error) {
m.counterLock.Lock()
defer m.counterLock.Unlock()
m.activeKeyCounter++
if m.activeKeyCounter >= maxWriteCount {
return m.newDataEncryptionKey()
}
return m.activeKey()
}
func (m *Manager) newDataEncryptionKey() ([]byte, uint32, error) {
dek := make([]byte, keySize)
_, err := rand.Read(dek)
if err != nil {
return nil, 0, err
}
m.lock.Lock()
defer m.lock.Unlock()
m.activeKeyCounter = 1
m.dataKeys = append(m.dataKeys, dek)
keyID := uint32(len(m.dataKeys) - 1)
return dek, keyID, nil
}
func (m *Manager) activeKey() ([]byte, uint32, error) {
m.lock.RLock()
defer m.lock.RUnlock()
nk := len(m.dataKeys)
if nk == 0 {
return nil, 0, ErrKeyNotFound
}
keyID := uint32(nk - 1)
return m.dataKeys[keyID], keyID, nil
}
func (m *Manager) key(keyID uint32) ([]byte, error) {
m.lock.RLock()
defer m.lock.RUnlock()
if len(m.dataKeys) <= int(keyID) {
return nil, fmt.Errorf("%w: %v", ErrKeyNotFound, keyID)
}
return m.dataKeys[keyID], nil
}

View File

@ -0,0 +1,327 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewManager(t *testing.T) {
m, err := NewManager()
if err != nil {
t.FailNow()
}
assert.NotNil(t, m)
}
func TestEncrypt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "test encrypt with arbitrary initial key", test: func(t *testing.T) {
testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181}
m, err := NewManager()
require.Nil(t, err)
m.dataKeys[0] = testDEK
testData := []byte("something")
cipherText, nonce, keyID, err := m.Encrypt(testData)
require.Nil(t, err)
dek := m.dataKeys[keyID]
b, err := aes.NewCipher(dek)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
decryptedData, err := aead.Open(nil, nonce, cipherText, nil)
require.Nil(t, err)
assert.Equal(t, testData, decryptedData)
}})
tests = append(tests, testCase{description: "test encrypt without arbitrary initial key", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
cipherText, nonce, keyID, err := m.Encrypt(testData)
require.Nil(t, err)
dek := m.dataKeys[keyID]
b, err := aes.NewCipher(dek)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
decryptedData, err := aead.Open(nil, nonce, cipherText, nil)
require.Nil(t, err)
assert.Equal(t, testData, decryptedData)
}})
tests = append(tests, testCase{description: "test encrypt: same data yield different cipher/nonce pair", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
cipher1, nonce1, keyID1, err := m.Encrypt(testData)
require.Nil(t, err)
assert.Len(t, cipher1, 25)
assert.Len(t, nonce1, 12)
assert.NotEmpty(t, cipher1)
assert.NotEmpty(t, nonce1)
cipher2, nonce2, keyID2, err := m.Encrypt(testData)
require.Nil(t, err)
assert.Equal(t, keyID1, keyID2)
assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher")
assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce")
}})
tests = append(tests, testCase{description: "test encrypt with key rotation", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
cipher1, nonce1, keyID1, err := m.Encrypt(testData)
require.Nil(t, err)
assert.Len(t, cipher1, 25)
assert.Len(t, nonce1, 12)
assert.NotEmpty(t, cipher1)
assert.NotEmpty(t, nonce1)
m.activeKeyCounter += maxWriteCount
cipher2, nonce2, keyID2, err := m.Encrypt(testData)
require.Nil(t, err)
assert.Equal(t, int64(1), m.activeKeyCounter)
assert.NotEqual(t, keyID1, keyID2)
assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher")
assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce")
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestDecrypt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "test decrypt with arbitrary key", test: func(t *testing.T) {
testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181}
m, err := NewManager()
require.Nil(t, err)
m.dataKeys[0] = testDEK
testData := []byte("something")
// encrypt data out of band.
b, err := aes.NewCipher(testDEK)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
nonce := make([]byte, aead.NonceSize())
_, err = rand.Read(nonce)
require.Nil(t, err)
cipherText := aead.Seal(nil, nonce, testData, nil)
// use manager to decrypt the data.
decryptedData, err := m.Decrypt(cipherText, nonce, 0)
require.Nil(t, err)
assert.Equal(t, testData, decryptedData)
},
})
tests = append(tests, testCase{description: "test decrypt without arbitrary key", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
// encrypt data out of band.
dek := m.dataKeys[0]
b, err := aes.NewCipher(dek)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
nonce := make([]byte, aead.NonceSize())
_, err = rand.Read(nonce)
require.Nil(t, err)
cipherText := aead.Seal(nil, nonce, testData, nil)
// use manager to decrypt the data.
decryptedData, err := m.Decrypt(cipherText, nonce, 0)
require.Nil(t, err)
assert.Equal(t, testData, decryptedData)
},
})
tests = append(tests, testCase{description: "test decrypt with wrong data nonce should return error", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
// encrypt data out of band.
dek := m.dataKeys[0]
b, err := aes.NewCipher(dek)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
nonce := make([]byte, aead.NonceSize())
_, err = rand.Read(nonce)
require.Nil(t, err)
cipherText := aead.Seal(nil, nonce, testData, nil)
// generate random nonce.
randomNonce := make([]byte, aead.NonceSize())
_, err = rand.Read(nonce)
require.Nil(t, err)
// decrypted encrypted data using encrypted dek
_, err = m.Decrypt(cipherText, randomNonce, 0)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "test decrypt with DEK/nonce pair not used to encrypt should return error", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
// encrypt data out of band.
dek := m.dataKeys[0]
b, err := aes.NewCipher(dek)
require.Nil(t, err)
aead, err := cipher.NewGCM(b)
require.Nil(t, err)
nonce := make([]byte, aead.NonceSize())
_, err = rand.Read(nonce)
require.Nil(t, err)
cipherText := aead.Seal(nil, nonce, testData, nil)
key, id, err := m.newDataEncryptionKey()
require.Nil(t, err)
m.dataKeys[id] = key
plainText, err := m.Decrypt(cipherText, nonce, id)
assert.NotNil(t, err)
assert.Nil(t, plainText)
},
})
tests = append(tests, testCase{description: "test decrypt for non active key", test: func(t *testing.T) {
m, err := NewManager()
require.Nil(t, err)
testData := []byte("something")
cipher, nonce, keyID, err := m.Encrypt(testData)
require.Nil(t, err)
// force key rotation.
m.activeKeyCounter += maxWriteCount
_, _, newKeyID, err := m.Encrypt(nil)
require.Nil(t, err)
require.NotEqual(t, keyID, newKeyID)
// use manager to decrypt the data.
decryptedData, err := m.Decrypt(cipher, nonce, keyID)
require.Nil(t, err)
assert.Equal(t, testData, decryptedData)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
var buf = make([]byte, 8192)
func BenchmarkEncryption(b *testing.B) {
benchEncrypt(b, 1024)
benchEncrypt(b, 4096)
benchEncrypt(b, 8192)
}
func BenchmarkDecryption(b *testing.B) {
benchDecrypt(b, 1024)
benchDecrypt(b, 4096)
benchDecrypt(b, 8192)
}
func benchEncrypt(b *testing.B, size int) {
m, err := NewManager()
if err != nil {
b.Fatal("failed to create manager", err)
}
// disable auto rotation to avoid skewing results.
maxWriteCount = math.MaxInt32
b.Run(fmt.Sprintf("encrypt-%d", size), func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_, _, _, err := m.Encrypt(buf[:size])
if err != nil {
b.Fatal("error encrypting data", err)
}
}
})
}
func benchDecrypt(b *testing.B, size int) {
m, err := NewManager()
if err != nil {
b.Fatal("failed to create manager", err)
}
edata, enonce, kid, err := m.Encrypt(buf[:size])
if err != nil {
b.Fatal("failed to encrypt data", err)
}
b.Run(fmt.Sprintf("decrypt-%d", size), func(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
_, err := m.Decrypt(edata, enonce, kid)
if err != nil {
b.Fatal("error encrypting data", err)
}
}
})
}

View File

@ -0,0 +1,204 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows
//
// Package informer is a generated GoMock package.
package informer
import (
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}
// MockRows is a mock of Rows interface.
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
}
// MockRowsMockRecorder is the mock recorder for MockRows.
type MockRowsMockRecorder struct {
mock *MockRows
}
// NewMockRows creates a new mock instance.
func NewMockRows(ctrl *gomock.Controller) *MockRows {
mock := &MockRows{ctrl: ctrl}
mock.recorder = &MockRowsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRows) EXPECT() *MockRowsMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRows) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRowsMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close))
}
// Err mocks base method.
func (m *MockRows) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err.
func (mr *MockRowsMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err))
}
// Next mocks base method.
func (m *MockRows) Next() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next")
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next.
func (mr *MockRowsMockRecorder) Next() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next))
}
// Scan mocks base method.
func (m *MockRows) Scan(arg0 ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Scan indicates an expected call of Scan.
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
}

View File

@ -0,0 +1,237 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface
//
// Package informer is a generated GoMock package.
package informer
import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
types "k8s.io/apimachinery/pkg/types"
watch "k8s.io/apimachinery/pkg/watch"
)
// MockResourceInterface is a mock of ResourceInterface interface.
type MockResourceInterface struct {
ctrl *gomock.Controller
recorder *MockResourceInterfaceMockRecorder
}
// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface.
type MockResourceInterfaceMockRecorder struct {
mock *MockResourceInterface
}
// NewMockResourceInterface creates a new mock instance.
func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface {
mock := &MockResourceInterface{ctrl: ctrl}
mock.recorder = &MockResourceInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder {
return m.recorder
}
// Apply mocks base method.
func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2, arg3}
for _, a := range arg4 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Apply", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Apply indicates an expected call of Apply.
func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...)
}
// ApplyStatus mocks base method.
func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ApplyStatus indicates an expected call of ApplyStatus.
func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3)
}
// Create mocks base method.
func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Create", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Create indicates an expected call of Create.
func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...)
}
// Delete mocks base method.
func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Delete", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...)
}
// DeleteCollection mocks base method.
func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteCollection indicates an expected call of DeleteCollection.
func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2)
}
// Get mocks base method.
func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Get", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...)
}
// List mocks base method.
func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", arg0, arg1)
ret0, _ := ret[0].(*unstructured.UnstructuredList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1)
}
// Patch mocks base method.
func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2, arg3, arg4}
for _, a := range arg5 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Patch", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Patch indicates an expected call of Patch.
func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...)
}
// Update mocks base method.
func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Update", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...)
}
// UpdateStatus mocks base method.
func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateStatus indicates an expected call of UpdateStatus.
func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2)
}
// Watch mocks base method.
func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Watch", arg0, arg1)
ret0, _ := ret[0].(watch.Interface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Watch indicates an expected call of Watch.
func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1)
}

View File

@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient
//
// Package factory is a generated GoMock package.
package factory
import (
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}

View File

@ -0,0 +1,237 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface
//
// Package factory is a generated GoMock package.
package factory
import (
context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
types "k8s.io/apimachinery/pkg/types"
watch "k8s.io/apimachinery/pkg/watch"
)
// MockResourceInterface is a mock of ResourceInterface interface.
type MockResourceInterface struct {
ctrl *gomock.Controller
recorder *MockResourceInterfaceMockRecorder
}
// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface.
type MockResourceInterfaceMockRecorder struct {
mock *MockResourceInterface
}
// NewMockResourceInterface creates a new mock instance.
func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface {
mock := &MockResourceInterface{ctrl: ctrl}
mock.recorder = &MockResourceInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder {
return m.recorder
}
// Apply mocks base method.
func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2, arg3}
for _, a := range arg4 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Apply", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Apply indicates an expected call of Apply.
func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...)
}
// ApplyStatus mocks base method.
func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ApplyStatus indicates an expected call of ApplyStatus.
func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3)
}
// Create mocks base method.
func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Create", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Create indicates an expected call of Create.
func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...)
}
// Delete mocks base method.
func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Delete", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...)
}
// DeleteCollection mocks base method.
func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteCollection indicates an expected call of DeleteCollection.
func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2)
}
// Get mocks base method.
func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Get", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...)
}
// List mocks base method.
func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List", arg0, arg1)
ret0, _ := ret[0].(*unstructured.UnstructuredList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List.
func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1)
}
// Patch mocks base method.
func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2, arg3, arg4}
for _, a := range arg5 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Patch", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Patch indicates an expected call of Patch.
func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...)
}
// Update mocks base method.
func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2}
for _, a := range arg3 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Update", varargs...)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Update indicates an expected call of Update.
func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2}, arg3...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...)
}
// UpdateStatus mocks base method.
func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2)
ret0, _ := ret[0].(*unstructured.Unstructured)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateStatus indicates an expected call of UpdateStatus.
func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2)
}
// Watch mocks base method.
func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Watch", arg0, arg1)
ret0, _ := ret[0].(watch.Interface)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Watch indicates an expected call of Watch.
func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1)
}

View File

@ -0,0 +1,179 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/informer/factory (interfaces: DBClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient
//
// Package factory is a generated GoMock package.
package factory
import (
context "context"
sql "database/sql"
reflect "reflect"
db "github.com/rancher/steve/pkg/sqlcache/db"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockDBClient is a mock of DBClient interface.
type MockDBClient struct {
ctrl *gomock.Controller
recorder *MockDBClientMockRecorder
}
// MockDBClientMockRecorder is the mock recorder for MockDBClient.
type MockDBClientMockRecorder struct {
mock *MockDBClient
}
// NewMockDBClient creates a new mock instance.
func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient {
mock := &MockDBClient{ctrl: ctrl}
mock.recorder = &MockDBClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder {
return m.recorder
}
// BeginTx mocks base method.
func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(db.TXClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1)
}
// CloseStmt mocks base method.
func (m *MockDBClient) CloseStmt(arg0 db.Closable) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseStmt", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CloseStmt indicates an expected call of CloseStmt.
func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0)
}
// NewConnection mocks base method.
func (m *MockDBClient) NewConnection() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NewConnection")
ret0, _ := ret[0].(error)
return ret0
}
// NewConnection indicates an expected call of NewConnection.
func (mr *MockDBClientMockRecorder) NewConnection() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockDBClient)(nil).NewConnection))
}
// Prepare mocks base method.
func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Prepare indicates an expected call of Prepare.
func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0)
}
// QueryForRows mocks base method.
func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryForRows", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryForRows indicates an expected call of QueryForRows.
func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...)
}
// ReadInt mocks base method.
func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadInt", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadInt indicates an expected call of ReadInt.
func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0)
}
// ReadObjects mocks base method.
func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2)
ret0, _ := ret[0].([]any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadObjects indicates an expected call of ReadObjects.
func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2)
}
// ReadStrings mocks base method.
func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadStrings", arg0)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadStrings indicates an expected call of ReadStrings.
func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0)
}
// Upsert mocks base method.
func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
return ret0
}
// Upsert indicates an expected call of Upsert.
func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4)
}

View File

@ -0,0 +1,186 @@
/*
Package factory provides a cache factory for the sql-based cache.
*/
package factory
import (
"fmt"
"os"
"sync"
"time"
"github.com/rancher/lasso/pkg/log"
"github.com/rancher/steve/pkg/sqlcache/db"
"github.com/rancher/steve/pkg/sqlcache/encryption"
"github.com/rancher/steve/pkg/sqlcache/informer"
sqlStore "github.com/rancher/steve/pkg/sqlcache/store"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/tools/cache"
)
// EncryptAllEnvVar is set to "true" if users want all types' data blobs to be encrypted in SQLite
// otherwise only variables in defaultEncryptedResourceTypes will have their blobs encrypted
const EncryptAllEnvVar = "CATTLE_ENCRYPT_CACHE_ALL"
// CacheFactory builds Informer instances and keeps a cache of instances it created
type CacheFactory struct {
wg wait.Group
dbClient DBClient
stopCh chan struct{}
mutex sync.RWMutex
encryptAll bool
newInformer newInformer
informers map[schema.GroupVersionKind]*guardedInformer
informersMutex sync.Mutex
}
type guardedInformer struct {
informer *informer.Informer
mutex *sync.Mutex
}
type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespace bool) (*informer.Informer, error)
type DBClient interface {
informer.DBClient
sqlStore.DBClient
connector
}
type Cache struct {
informer.ByOptionsLister
}
type connector interface {
NewConnection() error
}
var defaultEncryptedResourceTypes = map[schema.GroupVersionKind]struct{}{
{
Version: "v1",
Kind: "Secret",
}: {},
}
// NewCacheFactory returns an informer factory instance
func NewCacheFactory() (*CacheFactory, error) {
m, err := encryption.NewManager()
if err != nil {
return nil, err
}
dbClient, err := db.NewClient(nil, m, m)
if err != nil {
return nil, err
}
return &CacheFactory{
wg: wait.Group{},
stopCh: make(chan struct{}),
encryptAll: os.Getenv(EncryptAllEnvVar) == "true",
dbClient: dbClient,
newInformer: informer.NewInformer,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}, nil
}
// CacheFor returns an informer for given GVK, using sql store indexed with fields, using the specified client. For virtual fields, they must be added by the transform function
// and specified by fields to be used for later fields.
func (f *CacheFactory) CacheFor(fields [][]string, transform cache.TransformFunc, client dynamic.ResourceInterface, gvk schema.GroupVersionKind, namespaced bool, watchable bool) (Cache, error) {
// First of all block Reset() until we are done
f.mutex.RLock()
defer f.mutex.RUnlock()
// Second, check if the informer and its accompanying informer-specific mutex exist already in the informers cache
// If not, start by creating such informer-specific mutex. That is used later to ensure no two goroutines create
// informers for the same GVK at the same type
f.informersMutex.Lock()
// Note: the informers cache is protected by informersMutex, which we don't want to hold for very long because
// that blocks CacheFor for other GVKs, hence not deferring unlock here
gi, ok := f.informers[gvk]
if !ok {
gi = &guardedInformer{
informer: nil,
mutex: &sync.Mutex{},
}
f.informers[gvk] = gi
}
f.informersMutex.Unlock()
// At this point an informer-specific mutex (gi.mutex) is guaranteed to exist. Lock it
gi.mutex.Lock()
defer gi.mutex.Unlock()
// Then: if the informer really was not created yet (first time here or previous times have errored out)
// actually create the informer
if gi.informer == nil {
start := time.Now()
log.Debugf("CacheFor STARTS creating informer for %v", gvk)
defer func() {
log.Debugf("CacheFor IS DONE creating informer for %v (took %v)", gvk, time.Now().Sub(start))
}()
_, encryptResourceAlways := defaultEncryptedResourceTypes[gvk]
shouldEncrypt := f.encryptAll || encryptResourceAlways
i, err := f.newInformer(client, fields, transform, gvk, f.dbClient, shouldEncrypt, namespaced)
if err != nil {
return Cache{}, err
}
err = i.SetWatchErrorHandler(func(r *cache.Reflector, err error) {
if !watchable && errors.IsMethodNotSupported(err) {
// expected, continue without logging
return
}
cache.DefaultWatchErrorHandler(r, err)
})
if err != nil {
return Cache{}, err
}
f.wg.StartWithChannel(f.stopCh, i.Run)
gi.informer = i
}
if !cache.WaitForCacheSync(f.stopCh, gi.informer.HasSynced) {
return Cache{}, fmt.Errorf("failed to sync SQLite Informer cache for GVK %v", gvk)
}
// At this point the informer is ready, return it
return Cache{ByOptionsLister: gi.informer}, nil
}
// Reset closes the stopCh which stops any running informers, assigns a new stopCh, resets the GVK-informer cache, and resets
// the database connection which wipes any current sqlite database at the default location.
func (f *CacheFactory) Reset() error {
if f.dbClient == nil {
// nothing to reset
return nil
}
// first of all wait until all CacheFor() calls that create new informers are finished. Also block any new ones
f.mutex.Lock()
defer f.mutex.Unlock()
// now that we are alone, stop all informers created until this point
close(f.stopCh)
f.stopCh = make(chan struct{})
f.wg.Wait()
// and get rid of all references to those informers and their mutexes
f.informersMutex.Lock()
defer f.informersMutex.Unlock()
f.informers = make(map[schema.GroupVersionKind]*guardedInformer)
// finally, reset the DB connection
err := f.dbClient.NewConnection()
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,287 @@
package factory
import (
"os"
"testing"
"time"
"github.com/rancher/steve/pkg/sqlcache/informer"
sqlStore "github.com/rancher/steve/pkg/sqlcache/store"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/tools/cache"
)
//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient
//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient
//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface
//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer
func TestNewCacheFactory(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned, should return no errors", test: func(t *testing.T) {
f, err := NewCacheFactory()
assert.Nil(t, err)
assert.NotNil(t, f.dbClient)
assert.False(t, f.encryptAll)
}})
tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned and EncryptAllEnvVar set to true, should return no errors and have encryptAll set to true", test: func(t *testing.T) {
err := os.Setenv(EncryptAllEnvVar, "true")
assert.Nil(t, err)
f, err := NewCacheFactory()
assert.Nil(t, err)
assert.Nil(t, err)
assert.NotNil(t, f.dbClient)
assert.True(t, f.encryptAll)
}})
// cannot run as parallel because tests involve changing env var
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestCacheFor(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh not closed, should return no error and should call Informer.Run(). A subsequent call to CacheFor() should return same informer", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
expectedGVK := schema.GroupVersionKind{}
sii := NewMockSharedIndexInformer(gomock.NewController(t))
sii.EXPECT().HasSynced().Return(true).AnyTimes()
sii.EXPECT().Run(gomock.Any()).MinTimes(1)
sii.EXPECT().SetWatchErrorHandler(gomock.Any())
i := &informer.Informer{
// need to set this so Run function is not nil
SharedIndexInformer: sii,
}
expectedC := Cache{
ByOptionsLister: i,
}
testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) {
assert.Equal(t, client, dynamicClient)
assert.Equal(t, fields, fields)
assert.Equal(t, expectedGVK, gvk)
assert.Equal(t, db, dbClient)
assert.Equal(t, false, shouldEncrypt)
return i, nil
}
f := &CacheFactory{
dbClient: dbClient,
stopCh: make(chan struct{}),
newInformer: testNewInformer,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}
go func() {
// this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually
time.Sleep(5 * time.Second)
close(f.stopCh)
}()
var c Cache
var err error
c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
assert.Nil(t, err)
assert.Equal(t, expectedC, c)
// this sleep is critical to the test. It ensure there has been enough time for expected function like Run to be invoked in their go routines.
time.Sleep(1 * time.Second)
c2, err := f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
assert.Nil(t, err)
assert.Equal(t, c, c2)
}})
tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning false, and stopCh not closed, should call Run() and return an error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
expectedGVK := schema.GroupVersionKind{}
sii := NewMockSharedIndexInformer(gomock.NewController(t))
sii.EXPECT().HasSynced().Return(false).AnyTimes()
sii.EXPECT().Run(gomock.Any())
sii.EXPECT().SetWatchErrorHandler(gomock.Any())
expectedI := &informer.Informer{
// need to set this so Run function is not nil
SharedIndexInformer: sii,
}
testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) {
assert.Equal(t, client, dynamicClient)
assert.Equal(t, fields, fields)
assert.Equal(t, expectedGVK, gvk)
assert.Equal(t, db, dbClient)
assert.Equal(t, false, shouldEncrypt)
return expectedI, nil
}
f := &CacheFactory{
dbClient: dbClient,
stopCh: make(chan struct{}),
newInformer: testNewInformer,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}
go func() {
time.Sleep(1 * time.Second)
close(f.stopCh)
}()
var err error
_, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
assert.NotNil(t, err)
time.Sleep(2 * time.Second)
}})
tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh closed, should not call Run() more than once and not return an error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
expectedGVK := schema.GroupVersionKind{}
sii := NewMockSharedIndexInformer(gomock.NewController(t))
sii.EXPECT().HasSynced().Return(true).AnyTimes()
// may or may not call run initially
sii.EXPECT().Run(gomock.Any()).MaxTimes(1)
sii.EXPECT().SetWatchErrorHandler(gomock.Any())
i := &informer.Informer{
// need to set this so Run function is not nil
SharedIndexInformer: sii,
}
expectedC := Cache{
ByOptionsLister: i,
}
testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) {
assert.Equal(t, client, dynamicClient)
assert.Equal(t, fields, fields)
assert.Equal(t, expectedGVK, gvk)
assert.Equal(t, db, dbClient)
assert.Equal(t, false, shouldEncrypt)
return i, nil
}
f := &CacheFactory{
dbClient: dbClient,
stopCh: make(chan struct{}),
newInformer: testNewInformer,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}
close(f.stopCh)
var c Cache
var err error
c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
assert.Nil(t, err)
assert.Equal(t, expectedC, c)
time.Sleep(1 * time.Second)
}})
tests = append(tests, testCase{description: "CacheFor() with no errors returned and encryptAll set to true, should return no error and pass shouldEncrypt as true to newInformer func", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
expectedGVK := schema.GroupVersionKind{}
sii := NewMockSharedIndexInformer(gomock.NewController(t))
sii.EXPECT().HasSynced().Return(true)
sii.EXPECT().Run(gomock.Any()).MinTimes(1).AnyTimes()
sii.EXPECT().SetWatchErrorHandler(gomock.Any())
i := &informer.Informer{
// need to set this so Run function is not nil
SharedIndexInformer: sii,
}
expectedC := Cache{
ByOptionsLister: i,
}
testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) {
assert.Equal(t, client, dynamicClient)
assert.Equal(t, fields, fields)
assert.Equal(t, expectedGVK, gvk)
assert.Equal(t, db, dbClient)
assert.Equal(t, true, shouldEncrypt)
return i, nil
}
f := &CacheFactory{
dbClient: dbClient,
stopCh: make(chan struct{}),
newInformer: testNewInformer,
encryptAll: true,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}
go func() {
time.Sleep(10 * time.Second)
close(f.stopCh)
}()
var c Cache
var err error
c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true)
assert.Nil(t, err)
assert.Equal(t, expectedC, c)
time.Sleep(1 * time.Second)
}})
tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, stopCh not closed, and transform func should return no error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
expectedGVK := schema.GroupVersionKind{}
sii := NewMockSharedIndexInformer(gomock.NewController(t))
sii.EXPECT().HasSynced().Return(true)
sii.EXPECT().Run(gomock.Any()).MinTimes(1)
sii.EXPECT().SetWatchErrorHandler(gomock.Any())
transformFunc := func(input interface{}) (interface{}, error) {
return "someoutput", nil
}
i := &informer.Informer{
// need to set this so Run function is not nil
SharedIndexInformer: sii,
}
expectedC := Cache{
ByOptionsLister: i,
}
testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) {
// we can't test func == func, so instead we check if the output was as expected
input := "someinput"
ouput, err := transform(input)
assert.Nil(t, err)
outputStr, ok := ouput.(string)
assert.True(t, ok, "ouput from transform was expected to be a string")
assert.Equal(t, "someoutput", outputStr)
assert.Equal(t, client, dynamicClient)
assert.Equal(t, fields, fields)
assert.Equal(t, expectedGVK, gvk)
assert.Equal(t, db, dbClient)
assert.Equal(t, false, shouldEncrypt)
return i, nil
}
f := &CacheFactory{
dbClient: dbClient,
stopCh: make(chan struct{}),
newInformer: testNewInformer,
informers: map[schema.GroupVersionKind]*guardedInformer{},
}
go func() {
// this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually
time.Sleep(5 * time.Second)
close(f.stopCh)
}()
var c Cache
var err error
c, err = f.CacheFor(fields, transformFunc, dynamicClient, expectedGVK, false, true)
assert.Nil(t, err)
assert.Equal(t, expectedC, c)
time.Sleep(1 * time.Second)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

View File

@ -0,0 +1,223 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: k8s.io/client-go/tools/cache (interfaces: SharedIndexInformer)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer
//
// Package factory is a generated GoMock package.
package factory
import (
reflect "reflect"
time "time"
gomock "go.uber.org/mock/gomock"
cache "k8s.io/client-go/tools/cache"
)
// MockSharedIndexInformer is a mock of SharedIndexInformer interface.
type MockSharedIndexInformer struct {
ctrl *gomock.Controller
recorder *MockSharedIndexInformerMockRecorder
}
// MockSharedIndexInformerMockRecorder is the mock recorder for MockSharedIndexInformer.
type MockSharedIndexInformerMockRecorder struct {
mock *MockSharedIndexInformer
}
// NewMockSharedIndexInformer creates a new mock instance.
func NewMockSharedIndexInformer(ctrl *gomock.Controller) *MockSharedIndexInformer {
mock := &MockSharedIndexInformer{ctrl: ctrl}
mock.recorder = &MockSharedIndexInformerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSharedIndexInformer) EXPECT() *MockSharedIndexInformerMockRecorder {
return m.recorder
}
// AddEventHandler mocks base method.
func (m *MockSharedIndexInformer) AddEventHandler(arg0 cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddEventHandler", arg0)
ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddEventHandler indicates an expected call of AddEventHandler.
func (mr *MockSharedIndexInformerMockRecorder) AddEventHandler(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandler), arg0)
}
// AddEventHandlerWithResyncPeriod mocks base method.
func (m *MockSharedIndexInformer) AddEventHandlerWithResyncPeriod(arg0 cache.ResourceEventHandler, arg1 time.Duration) (cache.ResourceEventHandlerRegistration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddEventHandlerWithResyncPeriod", arg0, arg1)
ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddEventHandlerWithResyncPeriod indicates an expected call of AddEventHandlerWithResyncPeriod.
func (mr *MockSharedIndexInformerMockRecorder) AddEventHandlerWithResyncPeriod(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandlerWithResyncPeriod", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandlerWithResyncPeriod), arg0, arg1)
}
// AddIndexers mocks base method.
func (m *MockSharedIndexInformer) AddIndexers(arg0 cache.Indexers) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddIndexers", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddIndexers indicates an expected call of AddIndexers.
func (mr *MockSharedIndexInformerMockRecorder) AddIndexers(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIndexers", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddIndexers), arg0)
}
// GetController mocks base method.
func (m *MockSharedIndexInformer) GetController() cache.Controller {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetController")
ret0, _ := ret[0].(cache.Controller)
return ret0
}
// GetController indicates an expected call of GetController.
func (mr *MockSharedIndexInformerMockRecorder) GetController() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetController", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetController))
}
// GetIndexer mocks base method.
func (m *MockSharedIndexInformer) GetIndexer() cache.Indexer {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIndexer")
ret0, _ := ret[0].(cache.Indexer)
return ret0
}
// GetIndexer indicates an expected call of GetIndexer.
func (mr *MockSharedIndexInformerMockRecorder) GetIndexer() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndexer", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetIndexer))
}
// GetStore mocks base method.
func (m *MockSharedIndexInformer) GetStore() cache.Store {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStore")
ret0, _ := ret[0].(cache.Store)
return ret0
}
// GetStore indicates an expected call of GetStore.
func (mr *MockSharedIndexInformerMockRecorder) GetStore() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetStore))
}
// HasSynced mocks base method.
func (m *MockSharedIndexInformer) HasSynced() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasSynced")
ret0, _ := ret[0].(bool)
return ret0
}
// HasSynced indicates an expected call of HasSynced.
func (mr *MockSharedIndexInformerMockRecorder) HasSynced() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSynced", reflect.TypeOf((*MockSharedIndexInformer)(nil).HasSynced))
}
// IsStopped mocks base method.
func (m *MockSharedIndexInformer) IsStopped() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsStopped")
ret0, _ := ret[0].(bool)
return ret0
}
// IsStopped indicates an expected call of IsStopped.
func (mr *MockSharedIndexInformerMockRecorder) IsStopped() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStopped", reflect.TypeOf((*MockSharedIndexInformer)(nil).IsStopped))
}
// LastSyncResourceVersion mocks base method.
func (m *MockSharedIndexInformer) LastSyncResourceVersion() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LastSyncResourceVersion")
ret0, _ := ret[0].(string)
return ret0
}
// LastSyncResourceVersion indicates an expected call of LastSyncResourceVersion.
func (mr *MockSharedIndexInformerMockRecorder) LastSyncResourceVersion() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastSyncResourceVersion", reflect.TypeOf((*MockSharedIndexInformer)(nil).LastSyncResourceVersion))
}
// RemoveEventHandler mocks base method.
func (m *MockSharedIndexInformer) RemoveEventHandler(arg0 cache.ResourceEventHandlerRegistration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveEventHandler", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveEventHandler indicates an expected call of RemoveEventHandler.
func (mr *MockSharedIndexInformerMockRecorder) RemoveEventHandler(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).RemoveEventHandler), arg0)
}
// Run mocks base method.
func (m *MockSharedIndexInformer) Run(arg0 <-chan struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Run", arg0)
}
// Run indicates an expected call of Run.
func (mr *MockSharedIndexInformerMockRecorder) Run(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSharedIndexInformer)(nil).Run), arg0)
}
// SetTransform mocks base method.
func (m *MockSharedIndexInformer) SetTransform(arg0 cache.TransformFunc) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetTransform", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetTransform indicates an expected call of SetTransform.
func (mr *MockSharedIndexInformerMockRecorder) SetTransform(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransform", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetTransform), arg0)
}
// SetWatchErrorHandler mocks base method.
func (m *MockSharedIndexInformer) SetWatchErrorHandler(arg0 cache.WatchErrorHandler) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetWatchErrorHandler", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// SetWatchErrorHandler indicates an expected call of SetWatchErrorHandler.
func (mr *MockSharedIndexInformerMockRecorder) SetWatchErrorHandler(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWatchErrorHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetWatchErrorHandler), arg0)
}

View File

@ -0,0 +1,263 @@
package informer
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
"github.com/rancher/steve/pkg/sqlcache/db"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
"k8s.io/client-go/tools/cache"
)
const (
selectQueryFmt = `
SELECT object, objectnonce, dekid FROM "%[1]s"
WHERE key IN (
SELECT key FROM "%[1]s_indices"
WHERE name = ? AND value IN (?%s)
)
`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%[1]s_indices" (
name TEXT NOT NULL,
value TEXT NOT NULL,
key TEXT NOT NULL REFERENCES "%[1]s"(key) ON DELETE CASCADE,
PRIMARY KEY (name, value, key)
)`
createIndexFmt = `CREATE INDEX IF NOT EXISTS "%[1]s_indices_index" ON "%[1]s_indices"(name, value)`
deleteIndicesFmt = `DELETE FROM "%s_indices" WHERE key = ?`
addIndexFmt = `INSERT INTO "%s_indices" (name, value, key) VALUES (?, ?, ?) ON CONFLICT DO NOTHING`
listByIndexFmt = `SELECT object, objectnonce, dekid FROM "%[1]s"
WHERE key IN (
SELECT key FROM "%[1]s_indices"
WHERE name = ? AND value = ?
)`
listKeyByIndexFmt = `SELECT DISTINCT key FROM "%s_indices" WHERE name = ? AND value = ?`
listIndexValuesFmt = `SELECT DISTINCT value FROM "%s_indices" WHERE name = ?`
)
// Indexer is a SQLite-backed cache.Indexer which builds upon Store adding an index table
type Indexer struct {
Store
indexers cache.Indexers
indexersLock sync.RWMutex
deleteIndicesQuery string
addIndexQuery string
listByIndexQuery string
listKeysByIndexQuery string
listIndexValuesQuery string
deleteIndicesStmt *sql.Stmt
addIndexStmt *sql.Stmt
listByIndexStmt *sql.Stmt
listKeysByIndexStmt *sql.Stmt
listIndexValuesStmt *sql.Stmt
}
var _ cache.Indexer = (*Indexer)(nil)
type Store interface {
DBClient
cache.Store
GetByKey(key string) (item any, exists bool, err error)
GetName() string
RegisterAfterUpsert(f func(key string, obj any, tx db.TXClient) error)
RegisterAfterDelete(f func(key string, tx db.TXClient) error)
GetShouldEncrypt() bool
GetType() reflect.Type
}
type DBClient interface {
BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error)
QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error)
ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error)
ReadStrings(rows db.Rows) ([]string, error)
ReadInt(rows db.Rows) (int, error)
Prepare(stmt string) *sql.Stmt
CloseStmt(stmt db.Closable) error
}
// NewIndexer returns a cache.Indexer backed by SQLite for objects of the given example type
func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) {
tx, err := s.BeginTx(context.Background(), true)
if err != nil {
return nil, err
}
createTableQuery := fmt.Sprintf(createTableFmt, db.Sanitize(s.GetName()))
err = tx.Exec(createTableQuery)
if err != nil {
return nil, &db.QueryError{QueryString: createTableQuery, Err: err}
}
createIndexQuery := fmt.Sprintf(createIndexFmt, db.Sanitize(s.GetName()))
err = tx.Exec(createIndexQuery)
if err != nil {
return nil, &db.QueryError{QueryString: createIndexQuery, Err: err}
}
err = tx.Commit()
if err != nil {
return nil, err
}
i := &Indexer{
Store: s,
indexers: indexers,
}
i.RegisterAfterUpsert(i.AfterUpsert)
i.deleteIndicesQuery = fmt.Sprintf(deleteIndicesFmt, db.Sanitize(s.GetName()))
i.addIndexQuery = fmt.Sprintf(addIndexFmt, db.Sanitize(s.GetName()))
i.listByIndexQuery = fmt.Sprintf(listByIndexFmt, db.Sanitize(s.GetName()))
i.listKeysByIndexQuery = fmt.Sprintf(listKeyByIndexFmt, db.Sanitize(s.GetName()))
i.listIndexValuesQuery = fmt.Sprintf(listIndexValuesFmt, db.Sanitize(s.GetName()))
i.deleteIndicesStmt = s.Prepare(i.deleteIndicesQuery)
i.addIndexStmt = s.Prepare(i.addIndexQuery)
i.listByIndexStmt = s.Prepare(i.listByIndexQuery)
i.listKeysByIndexStmt = s.Prepare(i.listKeysByIndexQuery)
i.listIndexValuesStmt = s.Prepare(i.listIndexValuesQuery)
return i, nil
}
/* Core methods */
// AfterUpsert updates indices of an object
func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error {
// delete all
err := tx.StmtExec(tx.Stmt(i.deleteIndicesStmt), key)
if err != nil {
return &db.QueryError{QueryString: i.deleteIndicesQuery, Err: err}
}
// re-insert all
i.indexersLock.RLock()
defer i.indexersLock.RUnlock()
for indexName, indexFunc := range i.indexers {
values, err := indexFunc(obj)
if err != nil {
return err
}
for _, value := range values {
err = tx.StmtExec(tx.Stmt(i.addIndexStmt), indexName, value, key)
if err != nil {
return &db.QueryError{QueryString: i.addIndexQuery, Err: err}
}
}
}
return nil
}
/* Satisfy cache.Indexer */
// Index returns a list of items that match the given object on the index function
func (i *Indexer) Index(indexName string, obj any) ([]any, error) {
i.indexersLock.RLock()
defer i.indexersLock.RUnlock()
indexFunc := i.indexers[indexName]
if indexFunc == nil {
return nil, fmt.Errorf("index with name %s does not exist", indexName)
}
values, err := indexFunc(obj)
if err != nil {
return nil, err
}
if len(values) == 0 {
return nil, nil
}
// typical case
if len(values) == 1 {
return i.ByIndex(indexName, values[0])
}
// atypical case - more than one value to lookup
// HACK: sql.Statement.Query does not allow to pass slices in as of go 1.19 - create an ad-hoc statement
query := fmt.Sprintf(selectQueryFmt, db.Sanitize(i.GetName()), strings.Repeat(", ?", len(values)-1))
stmt := i.Prepare(query)
defer i.CloseStmt(stmt)
// HACK: Query will accept []any but not []string
params := []any{indexName}
for _, value := range values {
params = append(params, value)
}
rows, err := i.QueryForRows(context.TODO(), stmt, params...)
if err != nil {
return nil, &db.QueryError{QueryString: query, Err: err}
}
return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt())
}
// ByIndex returns the stored objects whose set of indexed values
// for the named index includes the given indexed value
func (i *Indexer) ByIndex(indexName, indexedValue string) ([]any, error) {
rows, err := i.QueryForRows(context.TODO(), i.listByIndexStmt, indexName, indexedValue)
if err != nil {
return nil, &db.QueryError{QueryString: i.listByIndexQuery, Err: err}
}
return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt())
}
// IndexKeys returns a list of the Store keys of the objects whose indexed values in the given index include the given indexed value
func (i *Indexer) IndexKeys(indexName, indexedValue string) ([]string, error) {
i.indexersLock.RLock()
defer i.indexersLock.RUnlock()
indexFunc := i.indexers[indexName]
if indexFunc == nil {
return nil, fmt.Errorf("Index with name %s does not exist", indexName)
}
rows, err := i.QueryForRows(context.TODO(), i.listKeysByIndexStmt, indexName, indexedValue)
if err != nil {
return nil, &db.QueryError{QueryString: i.listKeysByIndexQuery, Err: err}
}
return i.ReadStrings(rows)
}
// ListIndexFuncValues wraps safeListIndexFuncValues and panics in case of I/O errors
func (i *Indexer) ListIndexFuncValues(name string) []string {
result, err := i.safeListIndexFuncValues(name)
if err != nil {
panic(fmt.Errorf("unexpected error in safeListIndexFuncValues: %w", err))
}
return result
}
// safeListIndexFuncValues returns all the indexed values of the given index
func (i *Indexer) safeListIndexFuncValues(indexName string) ([]string, error) {
rows, err := i.QueryForRows(context.TODO(), i.listIndexValuesStmt, indexName)
if err != nil {
return nil, &db.QueryError{QueryString: i.listIndexValuesQuery, Err: err}
}
return i.ReadStrings(rows)
}
// GetIndexers returns the indexers
func (i *Indexer) GetIndexers() cache.Indexers {
i.indexersLock.RLock()
defer i.indexersLock.RUnlock()
return i.indexers
}
// AddIndexers adds more indexers to this Store. If you call this after you already have data
// in the Store, the results are undefined.
func (i *Indexer) AddIndexers(newIndexers cache.Indexers) error {
i.indexersLock.Lock()
defer i.indexersLock.Unlock()
if i.indexers == nil {
i.indexers = make(map[string]cache.IndexFunc)
}
for k, v := range newIndexers {
i.indexers[k] = v
}
return nil
}

View File

@ -0,0 +1,614 @@
/*
Copyright 2023 SUSE LLC
Adapted from client-go, Copyright 2014 The Kubernetes Authors.
*/
package informer
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"k8s.io/client-go/tools/cache"
)
//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store
//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows
type testStoreObject struct {
Id string
Val string
}
func TestNewIndexer(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or TXClient, should return no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
objKey := "objKey"
indexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
storeName := "someStoreName"
store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil)
store.EXPECT().GetName().AnyTimes().Return(storeName)
client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil)
client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil)
client.EXPECT().Commit().Return(nil)
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName))
store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName))
store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName))
store.EXPECT().Prepare(fmt.Sprintf(listKeyByIndexFmt, storeName))
store.EXPECT().Prepare(fmt.Sprintf(listIndexValuesFmt, storeName))
indexer, err := NewIndexer(indexers, store)
assert.Nil(t, err)
assert.Equal(t, cache.Indexers(indexers), indexer.indexers)
}})
tests = append(tests, testCase{description: "NewIndexer() with Store Begin() error, should return error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
objKey := "objKey"
indexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
store.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error"))
_, err := NewIndexer(indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on first call to Exec(), should return error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
objKey := "objKey"
indexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
storeName := "someStoreName"
store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil)
store.EXPECT().GetName().AnyTimes().Return(storeName)
client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(fmt.Errorf("error"))
_, err := NewIndexer(indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on second call to Exec(), should return error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
objKey := "objKey"
indexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
storeName := "someStoreName"
store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil)
store.EXPECT().GetName().AnyTimes().Return(storeName)
client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil)
client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(fmt.Errorf("error"))
_, err := NewIndexer(indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with TXClient Commit() error, should return error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
objKey := "objKey"
indexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
storeName := "someStoreName"
store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil)
store.EXPECT().GetName().AnyTimes().Return(storeName)
client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil)
client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil)
client.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewIndexer(indexers, store)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestAfterUpsert(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from TXClient should return no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
deleteStmt := &sql.Stmt{}
addStmt := &sql.Stmt{}
objKey := "key"
indexer := &Indexer{
Store: store,
deleteIndicesStmt: deleteStmt,
addIndexStmt: addStmt,
indexers: map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
key := "somekey"
client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt)
client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil)
client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt)
client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(nil)
testObject := testStoreObject{Id: "something", Val: "a"}
err := indexer.AfterUpsert(key, testObject, client)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient StmtExec() should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
deleteStmt := &sql.Stmt{}
addStmt := &sql.Stmt{}
objKey := "key"
indexer := &Indexer{
Store: store,
deleteIndicesStmt: deleteStmt,
addIndexStmt: addStmt,
indexers: map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
key := "somekey"
client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt)
client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(fmt.Errorf("error"))
testObject := testStoreObject{Id: "something", Val: "a"}
err := indexer.AfterUpsert(key, testObject, client)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient second StmtExec() call should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
client := NewMockTXClient(gomock.NewController(t))
deleteStmt := &sql.Stmt{}
addStmt := &sql.Stmt{}
objKey := "key"
indexer := &Indexer{
Store: store,
deleteIndicesStmt: deleteStmt,
addIndexStmt: addStmt,
indexers: map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
key := "somekey"
client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt)
client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil)
client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt)
client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(fmt.Errorf("error"))
testObject := testStoreObject{Id: "something", Val: "a"}
err := indexer.AfterUpsert(key, testObject, client)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestIndex(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Index() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
objs, err := indexer.Index(indexName, testObject)
assert.Nil(t, err)
assert.Equal(t, []any{testObject}, objs)
}})
tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil)
objs, err := indexer.Index(indexName, testObject)
assert.Nil(t, err)
assert.Equal(t, []any{testObject, testObject}, objs)
}})
tests = append(tests, testCase{description: "Index() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil)
objs, err := indexer.Index(indexName, testObject)
assert.Nil(t, err)
assert.Equal(t, []any{}, objs)
}})
tests = append(tests, testCase{description: "Index() where index name is not in indexers, should return error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
_, err := indexer.Index("someotherindexname", testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Index() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
_, err := indexer.Index(indexName, testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Index() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error"))
_, err := indexer.Index(indexName, testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple keys returned from index func, should return one obj and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
indexName: func(obj interface{}) ([]string, error) {
return []string{objKey, objKey + "2"}, nil
},
},
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().GetName().Return("name")
stmt := &sql.Stmt{}
store.EXPECT().Prepare(fmt.Sprintf(selectQueryFmt, "name", ", ?")).Return(stmt)
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
store.EXPECT().CloseStmt(stmt).Return(nil)
objs, err := indexer.Index(indexName, testObject)
assert.Nil(t, err)
assert.Equal(t, []any{testObject}, objs)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestByIndex(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
objs, err := indexer.ByIndex(indexName, objKey)
assert.Nil(t, err)
assert.Equal(t, []any{testObject}, objs)
}})
tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil)
objs, err := indexer.ByIndex(indexName, objKey)
assert.Nil(t, err)
assert.Equal(t, []any{testObject, testObject}, objs)
}})
tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil)
objs, err := indexer.ByIndex(indexName, objKey)
assert.Nil(t, err)
assert.Equal(t, []any{}, objs)
}})
tests = append(tests, testCase{description: "IndexBy() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
_, err := indexer.ByIndex(indexName, objKey)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "IndexBy() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error"))
_, err := indexer.ByIndex(indexName, objKey)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestListIndexFuncValues(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "ListIndexFuncvalues() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, nil)
vals := indexer.ListIndexFuncValues(indexName)
assert.Equal(t, []string{"somestrings"}, vals)
}})
tests = append(tests, testCase{description: "ListIndexFuncvalues() with QueryForRows() error returned from store, should panic", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error"))
assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) })
}})
tests = append(tests, testCase{description: "ListIndexFuncvalues() with ReadStrings() error returned from store, should panic", test: func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
rows := &sql.Rows{}
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, fmt.Errorf("error"))
assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) })
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestGetIndexers(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) {
objKey := "key"
expectedIndexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
indexer := &Indexer{
indexers: expectedIndexers,
}
indexers := indexer.GetIndexers()
assert.Equal(t, cache.Indexers(expectedIndexers), indexers)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestAddIndexers(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) {
objKey := "key"
expectedIndexers := map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
return []string{objKey}, nil
},
}
indexer := &Indexer{}
err := indexer.AddIndexers(expectedIndexers)
assert.Nil(t, err)
assert.ObjectsAreEqual(cache.Indexers(expectedIndexers), indexer.indexers)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

View File

@ -0,0 +1,94 @@
/*
package sql provides an Informer and Indexer that uses SQLite as a store, instead of an in-memory store like a map.
*/
package informer
import (
"context"
"time"
"github.com/rancher/steve/pkg/sqlcache/partition"
sqlStore "github.com/rancher/steve/pkg/sqlcache/store"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/tools/cache"
)
// Informer is a SQLite-backed cache.SharedIndexInformer that can execute queries on listprocessor structs
type Informer struct {
cache.SharedIndexInformer
ByOptionsLister
}
type ByOptionsLister interface {
ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error)
}
// this is set to a var so that it can be overridden by test code for mocking purposes
var newInformer = cache.NewSharedIndexInformer
// NewInformer returns a new SQLite-backed Informer for the type specified by schema in unstructured.Unstructured form
// using the specified client
func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*Informer, error) {
listWatcher := &cache.ListWatch{
ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
a, err := client.List(context.Background(), options)
return a, err
},
WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
return client.Watch(context.Background(), options)
},
}
example := &unstructured.Unstructured{}
example.SetGroupVersionKind(gvk)
// avoids the informer to periodically resync (re-list) its resources
// currently it is a work hypothesis that, when interacting with the UI, this should not be needed
resyncPeriod := time.Duration(0)
sii := newInformer(listWatcher, example, resyncPeriod, cache.Indexers{})
if transform != nil {
if err := sii.SetTransform(transform); err != nil {
return nil, err
}
}
name := informerNameFromGVK(gvk)
s, err := sqlStore.NewStore(example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, shouldEncrypt, name)
if err != nil {
return nil, err
}
loi, err := NewListOptionIndexer(fields, s, namespaced)
if err != nil {
return nil, err
}
// HACK: replace the default informer's indexer with the SQL based one
UnsafeSet(sii, "indexer", loi)
return &Informer{
SharedIndexInformer: sii,
ByOptionsLister: loi,
}, nil
}
// ListByOptions returns objects according to the specified list options and partitions.
// Specifically:
// - an unstructured list of resources belonging to any of the specified partitions
// - the total number of resources (returned list might be a subset depending on pagination options in lo)
// - a continue token, if there are more pages after the returned one
// - an error instead of all of the above if anything went wrong
func (i *Informer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) {
return i.ByOptionsLister.ListByOptions(ctx, lo, partitions, namespace)
}
func informerNameFromGVK(gvk schema.GroupVersionKind) string {
return gvk.Group + "_" + gvk.Version + "_" + gvk.Kind
}

View File

@ -0,0 +1,59 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister
//
// Package informer is a generated GoMock package.
package informer
import (
context "context"
reflect "reflect"
partition "github.com/rancher/steve/pkg/sqlcache/partition"
gomock "go.uber.org/mock/gomock"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)
// MockByOptionsLister is a mock of ByOptionsLister interface.
type MockByOptionsLister struct {
ctrl *gomock.Controller
recorder *MockByOptionsListerMockRecorder
}
// MockByOptionsListerMockRecorder is the mock recorder for MockByOptionsLister.
type MockByOptionsListerMockRecorder struct {
mock *MockByOptionsLister
}
// NewMockByOptionsLister creates a new mock instance.
func NewMockByOptionsLister(ctrl *gomock.Controller) *MockByOptionsLister {
mock := &MockByOptionsLister{ctrl: ctrl}
mock.recorder = &MockByOptionsListerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockByOptionsLister) EXPECT() *MockByOptionsListerMockRecorder {
return m.recorder
}
// ListByOptions mocks base method.
func (m *MockByOptionsLister) ListByOptions(arg0 context.Context, arg1 ListOptions, arg2 []partition.Partition, arg3 string) (*unstructured.UnstructuredList, int, string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListByOptions", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*unstructured.UnstructuredList)
ret1, _ := ret[1].(int)
ret2, _ := ret[2].(string)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// ListByOptions indicates an expected call of ListByOptions.
func (mr *MockByOptionsListerMockRecorder) ListByOptions(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByOptions", reflect.TypeOf((*MockByOptionsLister)(nil).ListByOptions), arg0, arg1, arg2, arg3)
}

View File

@ -0,0 +1,345 @@
package informer
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/tools/cache"
)
//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister
//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface
//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient
func TestNewInformer(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewInformer() with no errors returned, should return no error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
txClient := NewMockTXClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
// NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore
// is tested in depth in its own package.
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes()
// NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
// NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(context.Background(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
informer, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true)
assert.Nil(t, err)
assert.NotNil(t, informer.ByOptionsLister)
assert.NotNil(t, informer.SharedIndexInformer)
}})
tests = append(tests, testCase{description: "NewInformer() with errors returned from NewStore(), should return an error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
txClient := NewMockTXClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
// NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore
// is tested in depth in its own package.
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewInformer() with errors returned from NewIndexer(), should return an error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
txClient := NewMockTXClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
// NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore
// is tested in depth in its own package.
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes()
// NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewInformer() with errors returned from NewListOptionIndexer(), should return an error", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
txClient := NewMockTXClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
// NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore
// is tested in depth in its own package.
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes()
// NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
// NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewInformer() with transform func", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
txClient := NewMockTXClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
mockInformer := mockInformer{}
testNewInformer := func(lw cache.ListerWatcher,
exampleObject runtime.Object,
defaultEventHandlerResyncPeriod time.Duration,
indexers cache.Indexers) cache.SharedIndexInformer {
return &mockInformer
}
newInformer = testNewInformer
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
// NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore
// is tested in depth in its own package.
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes()
// NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
// NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer
// is tested in depth in its own indexer_test.go
dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
transformFunc := func(input interface{}) (interface{}, error) {
return "someoutput", nil
}
informer, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true)
assert.Nil(t, err)
assert.NotNil(t, informer.ByOptionsLister)
assert.NotNil(t, informer.SharedIndexInformer)
assert.NotNil(t, mockInformer.transformFunc)
// we can't test func == func, so instead we check if the output was as expected
input := "someinput"
ouput, err := mockInformer.transformFunc(input)
assert.Nil(t, err)
outputStr, ok := ouput.(string)
assert.True(t, ok, "ouput from transform was expected to be a string")
assert.Equal(t, "someoutput", outputStr)
newInformer = cache.NewSharedIndexInformer
}})
tests = append(tests, testCase{description: "NewInformer() unable to set transform func", test: func(t *testing.T) {
dbClient := NewMockDBClient(gomock.NewController(t))
dynamicClient := NewMockResourceInterface(gomock.NewController(t))
mockInformer := mockInformer{
setTranformErr: fmt.Errorf("some error"),
}
testNewInformer := func(lw cache.ListerWatcher,
exampleObject runtime.Object,
defaultEventHandlerResyncPeriod time.Duration,
indexers cache.Indexers) cache.SharedIndexInformer {
return &mockInformer
}
newInformer = testNewInformer
fields := [][]string{{"something"}}
gvk := schema.GroupVersionKind{}
transformFunc := func(input interface{}) (interface{}, error) {
return "someoutput", nil
}
_, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true)
assert.Error(t, err)
newInformer = cache.NewSharedIndexInformer
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestInformerListByOptions(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "ListByOptions() with no errors returned, should return no error and return value from indexer's ListByOptions()", test: func(t *testing.T) {
indexer := NewMockByOptionsLister(gomock.NewController(t))
informer := &Informer{
ByOptionsLister: indexer,
}
lo := ListOptions{}
var partitions []partition.Partition
ns := "somens"
expectedList := &unstructured.UnstructuredList{
Object: map[string]interface{}{"s": 2},
Items: []unstructured.Unstructured{{
Object: map[string]interface{}{"s": 2},
}},
}
expectedTotal := len(expectedList.Items)
expectedContinueToken := "123"
indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(expectedList, expectedTotal, expectedContinueToken, nil)
list, total, continueToken, err := informer.ListByOptions(context.TODO(), lo, partitions, ns)
assert.Nil(t, err)
assert.Equal(t, expectedList, list)
assert.Equal(t, len(expectedList.Items), total)
assert.Equal(t, expectedContinueToken, continueToken)
}})
tests = append(tests, testCase{description: "ListByOptions() with indexer ListByOptions error, should return error", test: func(t *testing.T) {
indexer := NewMockByOptionsLister(gomock.NewController(t))
informer := &Informer{
ByOptionsLister: indexer,
}
lo := ListOptions{}
var partitions []partition.Partition
ns := "somens"
indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(nil, 0, "", fmt.Errorf("error"))
_, _, _, err := informer.ListByOptions(context.TODO(), lo, partitions, ns)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
// Note: SQLite based caching uses an Informer that unsafely sets the Indexer as the ability to set it is not present
// in client-go at the moment. Long term, we look forward contribute a patch to client-go to make that configurable.
// Until then, we are adding this canary test that will panic in case the indexer cannot be set.
func TestUnsafeSet(t *testing.T) {
listWatcher := &cache.ListWatch{
ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
return &unstructured.UnstructuredList{}, nil
},
WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
return dummyWatch{}, nil
},
}
sii := cache.NewSharedIndexInformer(listWatcher, &unstructured.Unstructured{}, 0, cache.Indexers{})
// will panic if SharedIndexInformer stops having a *Indexer field called "indexer"
UnsafeSet(sii, "indexer", &Indexer{})
}
type dummyWatch struct{}
func (dummyWatch) Stop() {
}
func (dummyWatch) ResultChan() <-chan watch.Event {
result := make(chan watch.Event)
defer close(result)
return result
}
// mockInformer is a mock of cache.SharedIndexInformer. Unlike other types, we can't generate this using mockgen because we use a unsafeSet to replace the
// indexer field, which is a struct field. This won't exist on the mock, producing an error. So we need to implement our own mock which actually has this field.
type mockInformer struct {
transformFunc cache.TransformFunc
setTranformErr error
indexer cache.Indexer
}
func (m *mockInformer) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) {
return nil, nil
}
func (m *mockInformer) AddEventHandlerWithResyncPeriod(handler cache.ResourceEventHandler, resyncPeriod time.Duration) (cache.ResourceEventHandlerRegistration, error) {
return nil, nil
}
func (m *mockInformer) RemoveEventHandler(handle cache.ResourceEventHandlerRegistration) error {
return nil
}
func (m *mockInformer) GetStore() cache.Store { return nil }
func (m *mockInformer) GetController() cache.Controller { return nil }
func (m *mockInformer) Run(stopCh <-chan struct{}) {}
func (m *mockInformer) HasSynced() bool { return false }
func (m *mockInformer) LastSyncResourceVersion() string { return "" }
func (m *mockInformer) SetWatchErrorHandler(handler cache.WatchErrorHandler) error { return nil }
func (m *mockInformer) IsStopped() bool { return false }
func (m *mockInformer) AddIndexers(indexers cache.Indexers) error { return nil }
func (m *mockInformer) GetIndexer() cache.Indexer { return nil }
func (m *mockInformer) SetTransform(handler cache.TransformFunc) error {
m.transformFunc = handler
return m.setTranformErr
}

View File

@ -0,0 +1,551 @@
package informer
import (
"context"
"database/sql"
"encoding/gob"
"errors"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/client-go/tools/cache"
"github.com/rancher/steve/pkg/sqlcache/db"
"github.com/rancher/steve/pkg/sqlcache/partition"
)
// ListOptionIndexer extends Indexer by allowing queries based on ListOption
type ListOptionIndexer struct {
*Indexer
namespaced bool
indexedFields []string
addFieldQuery string
deleteFieldQuery string
addFieldStmt *sql.Stmt
deleteFieldStmt *sql.Stmt
}
var (
defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"}
defaultIndexNamespaced = "metadata.namespace"
subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[a-zA-Z./]+])|(\[[0-9]+])`)
ErrInvalidColumn = errors.New("supplied column is invalid")
)
const (
matchFmt = `%%%s%%`
strictMatchFmt = `%s`
createFieldsTableFmt = `CREATE TABLE "%s_fields" (
key TEXT NOT NULL PRIMARY KEY,
%s
)`
createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")`
failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items: %w"
)
// NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK
// ListOptionIndexer is also able to satisfy ListOption queries on indexed (sub)fields
// Fields are specified as slices (eg. "metadata.resourceVersion" is ["metadata", "resourceVersion"])
func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOptionIndexer, error) {
// necessary in order to gob/ungob unstructured.Unstructured objects
gob.Register(map[string]interface{}{})
gob.Register([]interface{}{})
i, err := NewIndexer(cache.Indexers{}, s)
if err != nil {
return nil, err
}
var indexedFields []string
for _, f := range defaultIndexedFields {
indexedFields = append(indexedFields, f)
}
if namespaced {
indexedFields = append(indexedFields, defaultIndexNamespaced)
}
for _, f := range fields {
indexedFields = append(indexedFields, toColumnName(f))
}
l := &ListOptionIndexer{
Indexer: i,
namespaced: namespaced,
indexedFields: indexedFields,
}
l.RegisterAfterUpsert(l.afterUpsert)
l.RegisterAfterDelete(l.afterDelete)
columnDefs := make([]string, len(indexedFields))
for index, field := range indexedFields {
column := fmt.Sprintf(`"%s" TEXT`, field)
columnDefs[index] = column
}
tx, err := l.BeginTx(context.Background(), true)
if err != nil {
return nil, err
}
err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, db.Sanitize(i.GetName()), strings.Join(columnDefs, ", ")))
if err != nil {
return nil, err
}
columns := make([]string, len(indexedFields))
qmarks := make([]string, len(indexedFields))
setStatements := make([]string, len(indexedFields))
for index, field := range indexedFields {
// create index for field
err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, db.Sanitize(i.GetName()), field, db.Sanitize(i.GetName()), field))
if err != nil {
return nil, err
}
// format field into column for prepared statement
column := fmt.Sprintf(`"%s"`, field)
columns[index] = column
// add placeholder for column's value in prepared statement
qmarks[index] = "?"
// add formatted set statement for prepared statement
setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field)
setStatements[index] = setStatement
}
err = tx.Commit()
if err != nil {
return nil, err
}
l.addFieldQuery = fmt.Sprintf(
`INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`,
db.Sanitize(i.GetName()),
strings.Join(columns, ", "),
strings.Join(qmarks, ", "),
strings.Join(setStatements, ", "),
)
l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, db.Sanitize(i.GetName()))
l.addFieldStmt = l.Prepare(l.addFieldQuery)
l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery)
return l, nil
}
/* Core methods */
// afterUpsert saves sortable/filterable fields into tables
func (l *ListOptionIndexer) afterUpsert(key string, obj any, tx db.TXClient) error {
args := []any{key}
for _, field := range l.indexedFields {
value, err := getField(obj, field)
if err != nil {
logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err)
cErr := tx.Cancel()
if cErr != nil {
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err)
}
return err
}
switch typedValue := value.(type) {
case nil:
args = append(args, "")
case int, bool, string:
args = append(args, fmt.Sprint(typedValue))
case []string:
args = append(args, strings.Join(typedValue, "|"))
default:
err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value)
cErr := tx.Cancel()
if cErr != nil {
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2)
}
return err2
}
}
err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...)
if err != nil {
return &db.QueryError{QueryString: l.addFieldQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) afterDelete(key string, tx db.TXClient) error {
args := []any{key}
err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...)
if err != nil {
return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err}
}
return nil
}
// ListByOptions returns objects according to the specified list options and partitions.
// Specifically:
// - an unstructured list of resources belonging to any of the specified partitions
// - the total number of resources (returned list might be a subset depending on pagination options in lo)
// - a continue token, if there are more pages after the returned one
// - an error instead of all of the above if anything went wrong
func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) {
// 1- Intro: SELECT and JOIN clauses
query := fmt.Sprintf(`SELECT o.object, o.objectnonce, o.dekid FROM "%s" o`, db.Sanitize(l.GetName()))
query += "\n "
query += fmt.Sprintf(`JOIN "%s_fields" f ON o.key = f.key`, db.Sanitize(l.GetName()))
params := []any{}
// 2- Filtering: WHERE clauses (from lo.Filters)
whereClauses := []string{}
for _, orFilters := range lo.Filters {
orClause, orParams, err := l.buildORClauseFromFilters(orFilters)
if err != nil {
return nil, 0, "", err
}
if orClause == "" {
continue
}
whereClauses = append(whereClauses, orClause)
params = append(params, orParams...)
}
// WHERE clauses (from namespace)
if namespace != "" && namespace != "*" {
whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`))
params = append(params, namespace)
}
// WHERE clauses (from partitions and their corresponding parameters)
partitionClauses := []string{}
for _, partition := range partitions {
if partition.Passthrough {
// nothing to do, no extra filtering to apply by definition
} else {
singlePartitionClauses := []string{}
// filter by namespace
if partition.Namespace != "" && partition.Namespace != "*" {
singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.namespace" = ?`))
params = append(params, partition.Namespace)
}
// optionally filter by names
if !partition.All {
names := partition.Names
if len(names) == 0 {
// degenerate case, there will be no results
singlePartitionClauses = append(singlePartitionClauses, "FALSE")
} else {
singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.name" IN (?%s)`, strings.Repeat(", ?", len(partition.Names)-1)))
// sort for reproducibility
sortedNames := partition.Names.UnsortedList()
sort.Strings(sortedNames)
for _, name := range sortedNames {
params = append(params, name)
}
}
}
if len(singlePartitionClauses) > 0 {
partitionClauses = append(partitionClauses, strings.Join(singlePartitionClauses, " AND "))
}
}
}
if len(partitions) == 0 {
// degenerate case, there will be no results
whereClauses = append(whereClauses, "FALSE")
}
if len(partitionClauses) == 1 {
whereClauses = append(whereClauses, partitionClauses[0])
}
if len(partitionClauses) > 1 {
whereClauses = append(whereClauses, "(\n ("+strings.Join(partitionClauses, ") OR\n (")+")\n)")
}
if len(whereClauses) > 0 {
query += "\n WHERE\n "
for index, clause := range whereClauses {
query += fmt.Sprintf("(%s)", clause)
if index == len(whereClauses)-1 {
break
}
query += " AND\n "
}
}
// 2- Sorting: ORDER BY clauses (from lo.Sort)
orderByClauses := []string{}
if len(lo.Sort.PrimaryField) > 0 {
columnName := toColumnName(lo.Sort.PrimaryField)
if err := l.validateColumn(columnName); err != nil {
return nil, 0, "", err
}
direction := "ASC"
if lo.Sort.PrimaryOrder == DESC {
direction = "DESC"
}
orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction))
}
if len(lo.Sort.SecondaryField) > 0 {
columnName := toColumnName(lo.Sort.SecondaryField)
if err := l.validateColumn(columnName); err != nil {
return nil, 0, "", err
}
direction := "ASC"
if lo.Sort.SecondaryOrder == DESC {
direction = "DESC"
}
orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction))
}
if len(orderByClauses) > 0 {
query += "\n ORDER BY "
query += strings.Join(orderByClauses, ", ")
} else {
// make sure one default order is always picked
if l.namespaced {
query += "\n ORDER BY f.\"metadata.namespace\" ASC, f.\"metadata.name\" ASC "
} else {
query += "\n ORDER BY f.\"metadata.name\" ASC "
}
}
// 4- Pagination: LIMIT clause (from lo.Pagination and/or lo.ChunkSize/lo.Resume)
// before proceeding, save a copy of the query and params without LIMIT/OFFSET
// for COUNTing all results later
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", query)
countParams := params[:]
limitClause := ""
// take the smallest limit between lo.Pagination and lo.ChunkSize
limit := lo.Pagination.PageSize
if limit == 0 || (lo.ChunkSize > 0 && lo.ChunkSize < limit) {
limit = lo.ChunkSize
}
if limit > 0 {
limitClause = "\n LIMIT ?"
params = append(params, limit)
}
// OFFSET clause (from lo.Pagination and/or lo.Resume)
offsetClause := ""
offset := 0
if lo.Resume != "" {
offsetInt, err := strconv.Atoi(lo.Resume)
if err != nil {
return nil, 0, "", err
}
offset = offsetInt
}
if lo.Pagination.Page >= 1 {
offset += lo.Pagination.PageSize * (lo.Pagination.Page - 1)
}
if offset > 0 {
offsetClause = "\n OFFSET ?"
params = append(params, offset)
}
// assemble and log the final query
query += limitClause
query += offsetClause
logrus.Debugf("ListOptionIndexer prepared statement: %v", query)
logrus.Debugf("Params: %v", params)
// execute
stmt := l.Prepare(query)
defer l.CloseStmt(stmt)
tx, err := l.BeginTx(ctx, false)
if err != nil {
return nil, 0, "", err
}
txStmt := tx.Stmt(stmt)
rows, err := txStmt.QueryContext(ctx, params...)
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", &db.QueryError{QueryString: query, Err: err}
}
items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt())
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", err
}
total := len(items)
// if limit or offset were set, execute counting of all rows
if limit > 0 || offset > 0 {
countStmt := l.Prepare(countQuery)
defer l.CloseStmt(countStmt)
txStmt := tx.Stmt(countStmt)
rows, err := txStmt.QueryContext(ctx, countParams...)
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", fmt.Errorf("error executing query: %w", err)
}
total, err = l.ReadInt(rows)
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", fmt.Errorf("error reading query results: %w", err)
}
}
if err := tx.Commit(); err != nil {
return nil, 0, "", err
}
continueToken := ""
if limit > 0 && offset+len(items) < total {
continueToken = fmt.Sprintf("%d", offset+limit)
}
return toUnstructuredList(items), total, continueToken, nil
}
func (l *ListOptionIndexer) validateColumn(column string) error {
for _, v := range l.indexedFields {
if v == column {
return nil
}
}
return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn)
}
// buildORClause creates an SQLite compatible query that ORs conditions built from passed filters
func (l *ListOptionIndexer) buildORClauseFromFilters(orFilters OrFilter) (string, []any, error) {
var orWhereClause string
var params []any
for index, filter := range orFilters.Filters {
opString := "LIKE"
if filter.Op == NotEq {
opString = "NOT LIKE"
}
columnName := toColumnName(filter.Field)
if err := l.validateColumn(columnName); err != nil {
return "", nil, err
}
orWhereClause += fmt.Sprintf(`f."%s" %s ? ESCAPE '\'`, columnName, opString)
format := strictMatchFmt
if filter.Partial {
format = matchFmt
}
match := filter.Match
// To allow matches on the backslash itself, the character needs to be replaced first.
// Otherwise, it will undo the following replacements.
match = strings.ReplaceAll(match, `\`, `\\`)
match = strings.ReplaceAll(match, `_`, `\_`)
match = strings.ReplaceAll(match, `%`, `\%`)
params = append(params, fmt.Sprintf(format, match))
if index == len(orFilters.Filters)-1 {
continue
}
orWhereClause += " OR "
}
return orWhereClause, params, nil
}
// toColumnName returns the column name corresponding to a field expressed as string slice
func toColumnName(s []string) string {
return db.Sanitize(strings.Join(s, "."))
}
// getField extracts the value of a field expressed as a string path from an unstructured object
func getField(a any, field string) (any, error) {
subFields := extractSubFields(field)
o, ok := a.(*unstructured.Unstructured)
if !ok {
return nil, fmt.Errorf("unexpected object type, expected unstructured.Unstructured: %v", a)
}
var obj interface{}
var found bool
var err error
obj = o.Object
for i, subField := range subFields {
switch t := obj.(type) {
case map[string]interface{}:
subField = strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]")
obj, found, err = unstructured.NestedFieldNoCopy(t, subField)
if err != nil {
return nil, err
}
if !found {
// particularly with labels/annotation indexes, it is totally possible that some objects won't have these,
// so either we this is not an error state or it could be an error state with a type that callers can check for
return nil, nil
}
case []interface{}:
if strings.HasPrefix(subField, "[") && strings.HasSuffix(subField, "]") {
key, err := strconv.Atoi(strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]"))
if err != nil {
return nil, fmt.Errorf("[listoption indexer] failed to convert subfield [%s] to int in listoption index: %w", subField, err)
}
if key >= len(t) {
return nil, fmt.Errorf("[listoption indexer] given index is too large for slice of len %d", len(t))
}
obj = fmt.Sprintf("%v", t[key])
} else if i == len(subFields)-1 {
result := make([]string, len(t))
for index, v := range t {
itemVal, ok := v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err)
}
itemStr, ok := itemVal[subField].(string)
if !ok {
return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err)
}
result[index] = itemStr
}
return result, nil
}
default:
return nil, fmt.Errorf("[listoption indexer] failed to parse subfields: %v", subFields)
}
}
return obj, nil
}
func extractSubFields(fields string) []string {
subfields := make([]string, 0)
for _, subField := range subfieldRegex.FindAllString(fields, -1) {
subfields = append(subfields, strings.TrimSuffix(subField, "."))
}
return subfields
}
// toUnstructuredList turns a slice of unstructured objects into an unstructured.UnstructuredList
func toUnstructuredList(items []any) *unstructured.UnstructuredList {
objectItems := make([]map[string]any, len(items))
result := &unstructured.UnstructuredList{
Items: make([]unstructured.Unstructured, len(items)),
Object: map[string]interface{}{"items": objectItems},
}
for i, item := range items {
result.Items[i] = *item.(*unstructured.Unstructured)
objectItems[i] = item.(*unstructured.Unstructured).Object
}
return result
}

View File

@ -0,0 +1,751 @@
/*
Copyright 2023 SUSE LLC
Adapted from client-go, Copyright 2014 The Kubernetes Authors.
*/
package informer
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/sets"
"github.com/rancher/steve/pkg/sqlcache/partition"
)
func TestNewListOptionIndexer(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewListOptionIndexer() with no errors returned, should return no error", test: func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
fields := [][]string{{"something"}}
id := "somename"
stmt := &sql.Stmt{}
// logic for NewIndexer(), only interested in if this results in error or not
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
store.EXPECT().GetName().Return(id).AnyTimes()
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes()
// end NewIndexer() logic
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().RegisterAfterDelete(gomock.Any())
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
// create field table
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil)
// create field table indexes
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil)
txClient.EXPECT().Commit().Return(nil)
loi, err := NewListOptionIndexer(fields, store, true)
assert.Nil(t, err)
assert.NotNil(t, loi)
}})
tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from NewIndxer(), should return an error", test: func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
fields := [][]string{{"something"}}
id := "somename"
// logic for NewIndexer(), only interested in if this results in error or not
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
store.EXPECT().GetName().Return(id).AnyTimes()
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewListOptionIndexer(fields, store, false)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from Begin(), should return an error", test: func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
fields := [][]string{{"something"}}
id := "somename"
stmt := &sql.Stmt{}
// logic for NewIndexer(), only interested in if this results in error or not
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
store.EXPECT().GetName().Return(id).AnyTimes()
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes()
// end NewIndexer() logic
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().RegisterAfterDelete(gomock.Any())
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, fmt.Errorf("error"))
_, err := NewListOptionIndexer(fields, store, false)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Exec() when creating fields table, should return an error", test: func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
fields := [][]string{{"something"}}
id := "somename"
stmt := &sql.Stmt{}
// logic for NewIndexer(), only interested in if this results in error or not
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
store.EXPECT().GetName().Return(id).AnyTimes()
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes()
// end NewIndexer() logic
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().RegisterAfterDelete(gomock.Any())
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(fmt.Errorf("error"))
_, err := NewListOptionIndexer(fields, store, true)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Commit(), should return an error", test: func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
fields := [][]string{{"something"}}
id := "somename"
stmt := &sql.Stmt{}
// logic for NewIndexer(), only interested in if this results in error or not
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
store.EXPECT().GetName().Return(id).AnyTimes()
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil)
txClient.EXPECT().Commit().Return(nil)
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes()
// end NewIndexer() logic
store.EXPECT().RegisterAfterUpsert(gomock.Any())
store.EXPECT().RegisterAfterDelete(gomock.Any())
store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil)
txClient.EXPECT().Commit().Return(fmt.Errorf("error"))
_, err := NewListOptionIndexer(fields, store, true)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestListByOptions(t *testing.T) {
type testCase struct {
description string
listOptions ListOptions
partitions []partition.Partition
ns string
expectedCountStmt string
expectedCountStmtArgs []any
expectedStmt string
expectedStmtArgs []any
expectedList *unstructured.UnstructuredList
returnList []any
expectedContToken string
expectedErr error
}
testObject := testStoreObject{Id: "something", Val: "a"}
unstrTestObjectMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&testObject)
assert.Nil(t, err)
// unstrTestObject
var tests []testCase
tests = append(tests, testCase{
description: "ListByOptions() with no errors returned, should not return an error",
listOptions: ListOptions{},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC `,
returnList: []any{},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions() with an empty filter, should not return an error",
listOptions: ListOptions{
Filters: []OrFilter{{[]Filter{}}},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with ChunkSize set should set limit in prepared sql.Stmt",
listOptions: ListOptions{ChunkSize: 2},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC
LIMIT ?`,
expectedStmtArgs: []interface{}{2},
expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC )`,
expectedCountStmtArgs: []interface{}{},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Resume set should set offset in prepared sql.Stmt",
listOptions: ListOptions{Resume: "4"},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC
OFFSET ?`,
expectedStmtArgs: []interface{}{4},
expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC )`,
expectedCountStmtArgs: []interface{}{},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with 1 OrFilter set with 1 filter should select where that filter is true in prepared sql.Stmt",
listOptions: ListOptions{Filters: []OrFilter{
{
[]Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somevalue",
},
},
},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.somefield" LIKE ? ESCAPE '\') AND
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"somevalue"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with 1 OrFilter set with 1 filter with Op set top NotEq should select where that filter is not true in prepared sql.Stmt",
listOptions: ListOptions{Filters: []OrFilter{
{
[]Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somevalue",
Op: NotEq,
},
},
},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"somevalue"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with 1 OrFilter set with 1 filter with Partial set to true should select where that partial match on that filter's value is true in prepared sql.Stmt",
listOptions: ListOptions{Filters: []OrFilter{
{
[]Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somevalue",
Partial: true,
},
},
},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.somefield" LIKE ? ESCAPE '\') AND
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"%somevalue%"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with 1 OrFilter set with multiple filters should select where any of those filters are true in prepared sql.Stmt",
listOptions: ListOptions{Filters: []OrFilter{
{
[]Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somevalue",
Partial: true,
},
{
Field: []string{"metadata", "somefield"},
Match: "someothervalue",
},
{
Field: []string{"metadata", "somefield"},
Match: "somethirdvalue",
Op: NotEq,
},
},
},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.somefield" LIKE ? ESCAPE '\' OR f."metadata.somefield" LIKE ? ESCAPE '\' OR f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"%somevalue%", "someothervalue", "somethirdvalue"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt",
listOptions: ListOptions{Filters: []OrFilter{
{
Filters: []Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somevalue",
Partial: true,
},
{
Field: []string{"status", "someotherfield"},
Match: "someothervalue",
Op: NotEq,
},
},
},
{
Filters: []Filter{
{
Field: []string{"metadata", "somefield"},
Match: "somethirdvalue",
Op: Eq,
},
},
},
},
},
partitions: []partition.Partition{},
ns: "test4",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.somefield" LIKE ? ESCAPE '\' OR f."status.someotherfield" NOT LIKE ? ESCAPE '\') AND
(f."metadata.somefield" LIKE ? ESCAPE '\') AND
(f."metadata.namespace" = ?) AND
(FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"%somevalue%", "someothervalue", "somethirdvalue", "test4"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.PrimaryField set only should sort on that field only, in ascending order in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
PrimaryField: []string{"metadata", "somefield"},
},
},
partitions: []partition.Partition{},
ns: "test5",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.namespace" = ?) AND
(FALSE)
ORDER BY f."metadata.somefield" ASC`,
expectedStmtArgs: []any{"test5"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.SecondaryField set only should sort on that field only, in ascending order in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
SecondaryField: []string{"metadata", "somefield"},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.somefield" ASC`,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.PrimaryField and Sort.SecondaryField set should sort on PrimaryField in ascending order first and then sort on SecondaryField in ascending order in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
PrimaryField: []string{"metadata", "somefield"},
SecondaryField: []string{"status", "someotherfield"},
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.somefield" ASC, f."status.someotherfield" ASC`,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.PrimaryField and Sort.SecondaryField set and PrimaryOrder set to DESC should sort on PrimaryField in descending order first and then sort on SecondaryField in ascending order in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
PrimaryField: []string{"metadata", "somefield"},
SecondaryField: []string{"status", "someotherfield"},
PrimaryOrder: DESC,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.SecondaryField set and Sort.PrimaryOrder set to descending should sort on that SecondaryField in ascending order only and ignore PrimaryOrder in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
SecondaryField: []string{"status", "someotherfield"},
PrimaryOrder: DESC,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."status.someotherfield" ASC`,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Sort.PrimaryOrder set only should sort on default primary and secondary fields in ascending order in prepared sql.Stmt",
listOptions: ListOptions{
Sort: Sort{
PrimaryOrder: DESC,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC `,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Pagination.PageSize set should set limit to PageSize in prepared sql.Stmt",
listOptions: ListOptions{
Pagination: Pagination{
PageSize: 10,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC
LIMIT ?`,
expectedStmtArgs: []any{10},
expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC )`,
expectedCountStmtArgs: []interface{}{},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Pagination.Page and no PageSize set should not add anything to prepared sql.Stmt",
listOptions: ListOptions{
Pagination: Pagination{
Page: 2,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC `,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with Pagination.Page and PageSize set limit to PageSize and offset to PageSize * (Page - 1) in prepared sql.Stmt",
listOptions: ListOptions{
Pagination: Pagination{
PageSize: 10,
Page: 2,
},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC
LIMIT ?
OFFSET ?`,
expectedStmtArgs: []any{10, 10},
expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC )`,
expectedCountStmtArgs: []interface{}{},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with a Namespace Partition should select only items where metadata.namespace is equal to Namespace and all other conditions are met in prepared sql.Stmt",
partitions: []partition.Partition{
{
Namespace: "somens",
},
},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.namespace" = ? AND FALSE)
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"somens"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with a All Partition should select all items that meet all other conditions in prepared sql.Stmt",
partitions: []partition.Partition{
{
All: true,
},
},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
ORDER BY f."metadata.name" ASC `,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with a Passthrough Partition should select all items that meet all other conditions prepared sql.Stmt",
partitions: []partition.Partition{
{
Passthrough: true,
},
},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
ORDER BY f."metadata.name" ASC `,
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
tests = append(tests, testCase{
description: "ListByOptions with a Names Partition should select only items where metadata.name equals an items in Names and all other conditions are met in prepared sql.Stmt",
partitions: []partition.Partition{
{
Names: sets.New[string]("someid", "someotherid"),
},
},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o
JOIN "something_fields" f ON o.key = f.key
WHERE
(f."metadata.name" IN (?, ?))
ORDER BY f."metadata.name" ASC `,
expectedStmtArgs: []any{"someid", "someotherid"},
returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}},
expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}},
expectedContToken: "",
expectedErr: nil,
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
txClient := NewMockTXClient(gomock.NewController(t))
store := NewMockStore(gomock.NewController(t))
stmts := NewMockStmt(gomock.NewController(t))
i := &Indexer{
Store: store,
}
lii := &ListOptionIndexer{
Indexer: i,
indexedFields: []string{"metadata.somefield", "status.someotherfield"},
}
stmt := &sql.Stmt{}
rows := &sql.Rows{}
objType := reflect.TypeOf(testObject)
store.EXPECT().BeginTx(gomock.Any(), false).Return(txClient, nil)
txClient.EXPECT().Stmt(gomock.Any()).Return(stmts).AnyTimes()
store.EXPECT().GetName().Return("something").AnyTimes()
store.EXPECT().Prepare(test.expectedStmt).Do(func(a ...any) {
fmt.Println(a)
}).Return(stmt)
if args := test.expectedStmtArgs; args != nil {
stmts.EXPECT().QueryContext(gomock.Any(), gomock.Any()).Return(rows, nil).AnyTimes()
} else if strings.Contains(test.expectedStmt, "LIMIT") {
stmts.EXPECT().QueryContext(gomock.Any(), args...).Return(rows, nil)
txClient.EXPECT().Stmt(gomock.Any()).Return(stmts)
stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil)
} else {
stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil)
}
store.EXPECT().GetType().Return(objType)
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, objType, false).Return(test.returnList, nil)
store.EXPECT().CloseStmt(stmt).Return(nil)
if test.expectedCountStmt != "" {
store.EXPECT().Prepare(test.expectedCountStmt).Return(stmt)
//store.EXPECT().QueryForRows(context.TODO(), stmt, test.expectedCountStmtArgs...).Return(rows, nil)
store.EXPECT().ReadInt(rows).Return(len(test.expectedList.Items), nil)
store.EXPECT().CloseStmt(stmt).Return(nil)
}
txClient.EXPECT().Commit()
list, total, contToken, err := lii.ListByOptions(context.TODO(), test.listOptions, test.partitions, test.ns)
if test.expectedErr == nil {
assert.Nil(t, err)
} else {
assert.Equal(t, test.expectedErr, err)
}
assert.Equal(t, test.expectedList, list)
assert.Equal(t, len(test.expectedList.Items), total)
assert.Equal(t, test.expectedContToken, contToken)
})
}
}

View File

@ -0,0 +1,59 @@
package informer
type Op string
const (
Eq Op = ""
NotEq Op = "!="
)
// SortOrder represents whether the list should be ascending or descending.
type SortOrder int
const (
// ASC stands for ascending order.
ASC SortOrder = iota
// DESC stands for descending (reverse) order.
DESC
)
// ListOptions represents the query parameters that may be included in a list request.
type ListOptions struct {
ChunkSize int
Resume string
Filters []OrFilter
Sort Sort
Pagination Pagination
}
// Filter represents a field to filter by.
// A subfield in an object is represented in a request query using . notation, e.g. 'metadata.name'.
// The subfield is internally represented as a slice, e.g. [metadata, name].
type Filter struct {
Field []string
Match string
Op Op
Partial bool
}
// OrFilter represents a set of possible fields to filter by, where an item may match any filter in the set to be included in the result.
type OrFilter struct {
Filters []Filter
}
// Sort represents the criteria to sort on.
// The subfield to sort by is represented in a request query using . notation, e.g. 'metadata.name'.
// The subfield is internally represented as a slice, e.g. [metadata, name].
// The order is represented by prefixing the sort key by '-', e.g. sort=-metadata.name.
type Sort struct {
PrimaryField []string
SecondaryField []string
PrimaryOrder SortOrder
SecondaryOrder SortOrder
}
// Pagination represents how to return paginated results.
type Pagination struct {
PageSize int
Page int
}

View File

@ -0,0 +1,22 @@
package informer
import (
"reflect"
"unsafe"
)
// UnsafeSet replaces the passed object's field value with the passed value.
func UnsafeSet(object any, field string, value any) {
rs := reflect.ValueOf(object).Elem()
rf := rs.FieldByName(field)
wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem()
wrf.Set(reflect.ValueOf(value))
}
// UnsafeGet returns the value of the passed object's for the passed field.
func UnsafeGet(object any, field string) any {
rs := reflect.ValueOf(object).Elem()
rf := rs.FieldByName(field)
wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem()
return wrf.Interface()
}

View File

@ -0,0 +1,325 @@
/*
Copyright 2023 SUSE LLC
Adapted from client-go, Copyright 2014 The Kubernetes Authors.
*/
package informer
import (
"fmt"
"k8s.io/client-go/tools/cache"
"strings"
"sync"
"testing"
"time"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
fcache "k8s.io/client-go/tools/cache/testing"
testingclock "k8s.io/utils/clock/testing"
)
type testListener struct {
lock sync.RWMutex
resyncPeriod time.Duration
expectedItemNames sets.Set[string]
receivedItemNames []string
name string
}
func newTestListener(name string, resyncPeriod time.Duration, expected ...string) *testListener {
l := &testListener{
resyncPeriod: resyncPeriod,
expectedItemNames: sets.New[string](expected...),
name: name,
}
return l
}
func (l *testListener) OnAdd(obj interface{}, isInInitialList bool) {
l.handle(obj)
}
func (l *testListener) OnUpdate(old, new interface{}) {
l.handle(new)
}
func (l *testListener) OnDelete(obj interface{}) {
}
func (l *testListener) handle(obj interface{}) {
key, _ := cache.MetaNamespaceKeyFunc(obj)
fmt.Printf("%s: handle: %v\n", l.name, key)
l.lock.Lock()
defer l.lock.Unlock()
objectMeta, _ := meta.Accessor(obj)
l.receivedItemNames = append(l.receivedItemNames, objectMeta.GetName())
}
func (l *testListener) ok() bool {
fmt.Println("polling")
err := wait.PollImmediate(100*time.Millisecond, 2*time.Second, func() (bool, error) {
if l.satisfiedExpectations() {
return true, nil
}
return false, nil
})
if err != nil {
return false
}
// wait just a bit to allow any unexpected stragglers to come in
fmt.Println("sleeping")
time.Sleep(1 * time.Second)
fmt.Println("final check")
return l.satisfiedExpectations()
}
func (l *testListener) satisfiedExpectations() bool {
l.lock.RLock()
defer l.lock.RUnlock()
return sets.New[string](l.receivedItemNames...).Equal(l.expectedItemNames)
}
func TestListenerResyncPeriods(t *testing.T) {
// source simulates an apiserver object endpoint.
source := fcache.NewFakeControllerSource()
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}})
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2"}})
// create the shared informer and resync every 1s
informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second)
clock := testingclock.NewFakeClock(time.Now())
UnsafeSet(informer, "clock", clock)
UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock)
// listener 1, never resync
listener1 := newTestListener("listener1", 0, "pod1", "pod2")
informer.AddEventHandlerWithResyncPeriod(listener1, listener1.resyncPeriod)
// listener 2, resync every 2s
listener2 := newTestListener("listener2", 2*time.Second, "pod1", "pod2")
informer.AddEventHandlerWithResyncPeriod(listener2, listener2.resyncPeriod)
// listener 3, resync every 3s
listener3 := newTestListener("listener3", 3*time.Second, "pod1", "pod2")
informer.AddEventHandlerWithResyncPeriod(listener3, listener3.resyncPeriod)
listeners := []*testListener{listener1, listener2, listener3}
stop := make(chan struct{})
defer close(stop)
go informer.Run(stop)
// ensure all listeners got the initial List
for _, listener := range listeners {
if !listener.ok() {
t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames)
}
}
// reset
for _, listener := range listeners {
listener.receivedItemNames = []string{}
}
// advance so listener2 gets a resync
clock.Step(2 * time.Second)
// make sure listener2 got the resync
if !listener2.ok() {
t.Errorf("%s: expected %v, got %v", listener2.name, listener2.expectedItemNames, listener2.receivedItemNames)
}
// wait a bit to give errant items a chance to go to 1 and 3
time.Sleep(1 * time.Second)
// make sure listeners 1 and 3 got nothing
if len(listener1.receivedItemNames) != 0 {
t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames))
}
if len(listener3.receivedItemNames) != 0 {
t.Errorf("listener3: should not have resynced (got %d)", len(listener3.receivedItemNames))
}
// reset
for _, listener := range listeners {
listener.receivedItemNames = []string{}
}
// advance so listener3 gets a resync
clock.Step(1 * time.Second)
// make sure listener3 got the resync
if !listener3.ok() {
t.Errorf("%s: expected %v, got %v", listener3.name, listener3.expectedItemNames, listener3.receivedItemNames)
}
// wait a bit to give errant items a chance to go to 1 and 2
time.Sleep(1 * time.Second)
// make sure listeners 1 and 2 got nothing
if len(listener1.receivedItemNames) != 0 {
t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames))
}
if len(listener2.receivedItemNames) != 0 {
t.Errorf("listener2: should not have resynced (got %d)", len(listener2.receivedItemNames))
}
}
// verify that https://github.com/kubernetes/kubernetes/issues/59822 is fixed
func TestSharedInformerInitializationRace(t *testing.T) {
source := fcache.NewFakeControllerSource()
informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second)
listener := newTestListener("raceListener", 0)
stop := make(chan struct{})
go informer.AddEventHandlerWithResyncPeriod(listener, listener.resyncPeriod)
go informer.Run(stop)
close(stop)
}
// TestSharedInformerWatchDisruption simulates a watch that was closed
// with updates to the store during that time. We ensure that handlers with
// resync and no resync see the expected state.
func TestSharedInformerWatchDisruption(t *testing.T) {
// source simulates an apiserver object endpoint.
source := fcache.NewFakeControllerSource()
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}})
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}})
// create the shared informer and resync every 1s
informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second)
clock := testingclock.NewFakeClock(time.Now())
UnsafeSet(informer, "clock", clock)
UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock)
// listener, never resync
listenerNoResync := newTestListener("listenerNoResync", 0, "pod1", "pod2")
informer.AddEventHandlerWithResyncPeriod(listenerNoResync, listenerNoResync.resyncPeriod)
listenerResync := newTestListener("listenerResync", 1*time.Second, "pod1", "pod2")
informer.AddEventHandlerWithResyncPeriod(listenerResync, listenerResync.resyncPeriod)
listeners := []*testListener{listenerNoResync, listenerResync}
stop := make(chan struct{})
defer close(stop)
go informer.Run(stop)
for _, listener := range listeners {
if !listener.ok() {
t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames)
}
}
// Add pod3, bump pod2 but don't broadcast it, so that the change will be seen only on relist
source.AddDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod3", UID: "pod3", ResourceVersion: "3"}})
source.ModifyDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "4"}})
// Ensure that nobody saw any changes
for _, listener := range listeners {
if !listener.ok() {
t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames)
}
}
for _, listener := range listeners {
listener.receivedItemNames = []string{}
}
listenerNoResync.expectedItemNames = sets.New[string]("pod2", "pod3")
listenerResync.expectedItemNames = sets.New[string]("pod1", "pod2", "pod3")
// This calls shouldSync, which deletes noResync from the list of syncingListeners
clock.Step(1 * time.Second)
// Simulate a connection loss (or even just a too-old-watch)
source.ResetWatch()
// Wait long enough for the reflector to exit and the backoff function to start waiting
// on the fake clock, otherwise advancing the fake clock will have no effect.
// TODO: Make this deterministic by counting the number of waiters on FakeClock
time.Sleep(10 * time.Millisecond)
// Advance the clock to cause the backoff wait to expire.
clock.Step(1601 * time.Millisecond)
// Wait long enough for backoff to invoke ListWatch a second time and distribute events
// to listeners.
time.Sleep(10 * time.Millisecond)
for _, listener := range listeners {
if !listener.ok() {
t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames)
}
}
}
func TestSharedInformerErrorHandling(t *testing.T) {
source := fcache.NewFakeControllerSource()
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}})
source.ListError = fmt.Errorf("Access Denied")
informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second)
errCh := make(chan error)
_ = informer.SetWatchErrorHandler(func(_ *cache.Reflector, err error) {
errCh <- err
})
stop := make(chan struct{})
go informer.Run(stop)
select {
case err := <-errCh:
if !strings.Contains(err.Error(), "Access Denied") {
t.Errorf("Expected 'Access Denied' error. Actual: %v", err)
}
case <-time.After(time.Second):
t.Errorf("Timeout waiting for error handler call")
}
close(stop)
}
func TestSharedInformerTransformer(t *testing.T) {
// source simulates an apiserver object endpoint.
source := fcache.NewFakeControllerSource()
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}})
source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}})
informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second)
informer.SetTransform(func(obj interface{}) (interface{}, error) {
if pod, ok := obj.(*v1.Pod); ok {
name := pod.GetName()
if upper := strings.ToUpper(name); upper != name {
copied := pod.DeepCopyObject().(*v1.Pod)
copied.SetName(upper)
return copied, nil
}
}
return obj, nil
})
listenerTransformer := newTestListener("listenerTransformer", 0, "POD1", "POD2")
informer.AddEventHandler(listenerTransformer)
stop := make(chan struct{})
go informer.Run(stop)
defer close(stop)
if !listenerTransformer.ok() {
t.Errorf("%s: expected %v, got %v", listenerTransformer.name, listenerTransformer.expectedItemNames, listenerTransformer.receivedItemNames)
}
}

View File

@ -0,0 +1,347 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: Store)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store
//
// Package informer is a generated GoMock package.
package informer
import (
context "context"
sql "database/sql"
reflect "reflect"
db "github.com/rancher/steve/pkg/sqlcache/db"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockStore is a mock of Store interface.
type MockStore struct {
ctrl *gomock.Controller
recorder *MockStoreMockRecorder
}
// MockStoreMockRecorder is the mock recorder for MockStore.
type MockStoreMockRecorder struct {
mock *MockStore
}
// NewMockStore creates a new mock instance.
func NewMockStore(ctrl *gomock.Controller) *MockStore {
mock := &MockStore{ctrl: ctrl}
mock.recorder = &MockStoreMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStore) EXPECT() *MockStoreMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockStore) Add(arg0 any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Add", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Add indicates an expected call of Add.
func (mr *MockStoreMockRecorder) Add(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), arg0)
}
// BeginTx mocks base method.
func (m *MockStore) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(db.TXClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockStoreMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockStore)(nil).BeginTx), arg0, arg1)
}
// CloseStmt mocks base method.
func (m *MockStore) CloseStmt(arg0 db.Closable) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseStmt", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CloseStmt indicates an expected call of CloseStmt.
func (mr *MockStoreMockRecorder) CloseStmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockStore)(nil).CloseStmt), arg0)
}
// Delete mocks base method.
func (m *MockStore) Delete(arg0 any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockStoreMockRecorder) Delete(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), arg0)
}
// Get mocks base method.
func (m *MockStore) Get(arg0 any) (any, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(any)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// Get indicates an expected call of Get.
func (mr *MockStoreMockRecorder) Get(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), arg0)
}
// GetByKey mocks base method.
func (m *MockStore) GetByKey(arg0 string) (any, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetByKey", arg0)
ret0, _ := ret[0].(any)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetByKey indicates an expected call of GetByKey.
func (mr *MockStoreMockRecorder) GetByKey(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByKey", reflect.TypeOf((*MockStore)(nil).GetByKey), arg0)
}
// GetName mocks base method.
func (m *MockStore) GetName() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetName")
ret0, _ := ret[0].(string)
return ret0
}
// GetName indicates an expected call of GetName.
func (mr *MockStoreMockRecorder) GetName() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockStore)(nil).GetName))
}
// GetShouldEncrypt mocks base method.
func (m *MockStore) GetShouldEncrypt() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetShouldEncrypt")
ret0, _ := ret[0].(bool)
return ret0
}
// GetShouldEncrypt indicates an expected call of GetShouldEncrypt.
func (mr *MockStoreMockRecorder) GetShouldEncrypt() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShouldEncrypt", reflect.TypeOf((*MockStore)(nil).GetShouldEncrypt))
}
// GetType mocks base method.
func (m *MockStore) GetType() reflect.Type {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetType")
ret0, _ := ret[0].(reflect.Type)
return ret0
}
// GetType indicates an expected call of GetType.
func (mr *MockStoreMockRecorder) GetType() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetType", reflect.TypeOf((*MockStore)(nil).GetType))
}
// List mocks base method.
func (m *MockStore) List() []any {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "List")
ret0, _ := ret[0].([]any)
return ret0
}
// List indicates an expected call of List.
func (mr *MockStoreMockRecorder) List() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStore)(nil).List))
}
// ListKeys mocks base method.
func (m *MockStore) ListKeys() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListKeys")
ret0, _ := ret[0].([]string)
return ret0
}
// ListKeys indicates an expected call of ListKeys.
func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKeys", reflect.TypeOf((*MockStore)(nil).ListKeys))
}
// Prepare mocks base method.
func (m *MockStore) Prepare(arg0 string) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Prepare indicates an expected call of Prepare.
func (mr *MockStoreMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockStore)(nil).Prepare), arg0)
}
// QueryForRows mocks base method.
func (m *MockStore) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryForRows", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryForRows indicates an expected call of QueryForRows.
func (mr *MockStoreMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockStore)(nil).QueryForRows), varargs...)
}
// ReadInt mocks base method.
func (m *MockStore) ReadInt(arg0 db.Rows) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadInt", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadInt indicates an expected call of ReadInt.
func (mr *MockStoreMockRecorder) ReadInt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockStore)(nil).ReadInt), arg0)
}
// ReadObjects mocks base method.
func (m *MockStore) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2)
ret0, _ := ret[0].([]any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadObjects indicates an expected call of ReadObjects.
func (mr *MockStoreMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockStore)(nil).ReadObjects), arg0, arg1, arg2)
}
// ReadStrings mocks base method.
func (m *MockStore) ReadStrings(arg0 db.Rows) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadStrings", arg0)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadStrings indicates an expected call of ReadStrings.
func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockStore)(nil).ReadStrings), arg0)
}
// RegisterAfterDelete mocks base method.
func (m *MockStore) RegisterAfterDelete(arg0 func(string, db.TXClient) error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterDelete", arg0)
}
// RegisterAfterDelete indicates an expected call of RegisterAfterDelete.
func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0)
}
// RegisterAfterUpsert mocks base method.
func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, db.TXClient) error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterUpsert", arg0)
}
// RegisterAfterUpsert indicates an expected call of RegisterAfterUpsert.
func (mr *MockStoreMockRecorder) RegisterAfterUpsert(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpsert", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpsert), arg0)
}
// Replace mocks base method.
func (m *MockStore) Replace(arg0 []any, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Replace", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Replace indicates an expected call of Replace.
func (mr *MockStoreMockRecorder) Replace(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockStore)(nil).Replace), arg0, arg1)
}
// Resync mocks base method.
func (m *MockStore) Resync() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Resync")
ret0, _ := ret[0].(error)
return ret0
}
// Resync indicates an expected call of Resync.
func (mr *MockStoreMockRecorder) Resync() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resync", reflect.TypeOf((*MockStore)(nil).Resync))
}
// Update mocks base method.
func (m *MockStore) Update(arg0 any) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Update", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Update indicates an expected call of Update.
func (mr *MockStoreMockRecorder) Update(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), arg0)
}

View File

@ -0,0 +1,165 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient
//
// Package informer is a generated GoMock package.
package informer
import (
context "context"
sql "database/sql"
reflect "reflect"
db "github.com/rancher/steve/pkg/sqlcache/db"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockDBClient is a mock of DBClient interface.
type MockDBClient struct {
ctrl *gomock.Controller
recorder *MockDBClientMockRecorder
}
// MockDBClientMockRecorder is the mock recorder for MockDBClient.
type MockDBClientMockRecorder struct {
mock *MockDBClient
}
// NewMockDBClient creates a new mock instance.
func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient {
mock := &MockDBClient{ctrl: ctrl}
mock.recorder = &MockDBClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder {
return m.recorder
}
// BeginTx mocks base method.
func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(db.TXClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1)
}
// CloseStmt mocks base method.
func (m *MockDBClient) CloseStmt(arg0 db.Closable) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseStmt", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CloseStmt indicates an expected call of CloseStmt.
func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0)
}
// Prepare mocks base method.
func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Prepare indicates an expected call of Prepare.
func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0)
}
// QueryForRows mocks base method.
func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryForRows", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryForRows indicates an expected call of QueryForRows.
func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...)
}
// ReadInt mocks base method.
func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadInt", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadInt indicates an expected call of ReadInt.
func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0)
}
// ReadObjects mocks base method.
func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2)
ret0, _ := ret[0].([]any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadObjects indicates an expected call of ReadObjects.
func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2)
}
// ReadStrings mocks base method.
func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadStrings", arg0)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadStrings indicates an expected call of ReadStrings.
func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0)
}
// Upsert mocks base method.
func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
return ret0
}
// Upsert indicates an expected call of Upsert.
func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4)
}

View File

@ -0,0 +1,99 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt
//
// Package informer is a generated GoMock package.
package informer
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}

View File

@ -0,0 +1,365 @@
package sql
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/dynamic"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/envtest"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/informer/factory"
"github.com/rancher/steve/pkg/sqlcache/partition"
)
const testNamespace = "sql-test"
var defaultPartition = partition.Partition{
All: true,
}
type IntegrationSuite struct {
suite.Suite
testEnv envtest.Environment
clientset kubernetes.Clientset
restCfg rest.Config
}
func (i *IntegrationSuite) SetupSuite() {
i.testEnv = envtest.Environment{}
restCfg, err := i.testEnv.Start()
i.Require().NoError(err, "error when starting env test - this is likely because setup-envtest wasn't done. Check the README for more information")
i.restCfg = *restCfg
clientset, err := kubernetes.NewForConfig(restCfg)
i.Require().NoError(err)
i.clientset = *clientset
testNs := v1.Namespace{
ObjectMeta: metav1.ObjectMeta{
Name: testNamespace,
},
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, err = i.clientset.CoreV1().Namespaces().Create(ctx, &testNs, metav1.CreateOptions{})
i.Require().NoError(err)
}
func (i *IntegrationSuite) TearDownSuite() {
err := i.testEnv.Stop()
i.Require().NoError(err)
}
func (i *IntegrationSuite) TestSQLCacheFilters() {
fields := [][]string{{`metadata`, `annotations[somekey]`}}
require := i.Require()
configMapWithAnnotations := func(name string, annotations map[string]string) v1.ConfigMap {
return v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: testNamespace,
Annotations: annotations,
},
}
}
createConfigMaps := func(configMaps ...v1.ConfigMap) {
for _, configMap := range configMaps {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
configMapClient := i.clientset.CoreV1().ConfigMaps(testNamespace)
_, err := configMapClient.Create(ctx, &configMap, metav1.CreateOptions{})
require.NoError(err)
// avoiding defer in a for loop
cancel()
}
}
// we create some configmaps before the cache starts and some after so that we can test the initial list and
// subsequent watches to make sure both work
// matches the filter for somekey == somevalue
matches := configMapWithAnnotations("matches-filter", map[string]string{"somekey": "somevalue"})
// partial match for somekey == somevalue (different suffix)
partialMatches := configMapWithAnnotations("partial-matches", map[string]string{"somekey": "somevaluehere"})
specialCharacterMatch := configMapWithAnnotations("special-character-matches", map[string]string{"somekey": "c%%l_value"})
backSlashCharacterMatch := configMapWithAnnotations("backslash-character-matches", map[string]string{"somekey": `my\windows\path`})
createConfigMaps(matches, partialMatches, specialCharacterMatch, backSlashCharacterMatch)
cache, cacheFactory, err := i.createCacheAndFactory(fields, nil)
require.NoError(err)
defer cacheFactory.Reset()
// doesn't match the filter for somekey == somevalue
notMatches := configMapWithAnnotations("not-matches-filter", map[string]string{"somekey": "notequal"})
// has no annotations, shouldn't match any filter
missing := configMapWithAnnotations("missing", nil)
createConfigMaps(notMatches, missing)
configMapNames := []string{matches.Name, partialMatches.Name, notMatches.Name, missing.Name, specialCharacterMatch.Name, backSlashCharacterMatch.Name}
err = i.waitForCacheReady(configMapNames, testNamespace, cache)
require.NoError(err)
orFiltersForFilters := func(filters ...informer.Filter) []informer.OrFilter {
return []informer.OrFilter{
{
Filters: filters,
},
}
}
tests := []struct {
name string
filters []informer.OrFilter
wantNames []string
}{
{
name: "matches filter",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.Eq,
Partial: false,
}),
wantNames: []string{"matches-filter"},
},
{
name: "partial matches filter",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.Eq,
Partial: true,
}),
wantNames: []string{"matches-filter", "partial-matches"},
},
{
name: "no matches for filter with underscore as it is interpreted literally",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalu_",
Op: informer.Eq,
Partial: true,
}),
wantNames: nil,
},
{
name: "no matches for filter with percent sign as it is interpreted literally",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalu%",
Op: informer.Eq,
Partial: true,
}),
wantNames: nil,
},
{
name: "match with special characters",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "c%%l_value",
Op: informer.Eq,
Partial: true,
}),
wantNames: []string{"special-character-matches"},
},
{
name: "match with literal backslash character",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: `my\windows\path`,
Op: informer.Eq,
Partial: true,
}),
wantNames: []string{"backslash-character-matches"},
},
{
name: "not eq filter",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.NotEq,
Partial: false,
}),
wantNames: []string{"partial-matches", "not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"},
},
{
name: "partial not eq filter",
filters: orFiltersForFilters(informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.NotEq,
Partial: true,
}),
wantNames: []string{"not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"},
},
{
name: "multiple or filters match",
filters: orFiltersForFilters(
informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.Eq,
Partial: true,
},
informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "notequal",
Op: informer.Eq,
Partial: false,
},
),
wantNames: []string{"matches-filter", "partial-matches", "not-matches-filter"},
},
{
name: "or filters on different fields",
filters: orFiltersForFilters(
informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.Eq,
Partial: true,
},
informer.Filter{
Field: []string{`metadata`, `name`},
Match: "missing",
Op: informer.Eq,
Partial: false,
},
),
wantNames: []string{"matches-filter", "partial-matches", "missing"},
},
{
name: "and filters, both must match",
filters: []informer.OrFilter{
{
Filters: []informer.Filter{
{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "somevalue",
Op: informer.Eq,
Partial: true,
},
},
},
{
Filters: []informer.Filter{
{
Field: []string{`metadata`, `name`},
Match: "matches-filter",
Op: informer.Eq,
Partial: false,
},
},
},
},
wantNames: []string{"matches-filter"},
},
{
name: "no matches",
filters: orFiltersForFilters(
informer.Filter{
Field: []string{`metadata`, `annotations[somekey]`},
Match: "valueNotRepresented",
Op: informer.Eq,
Partial: false,
},
),
wantNames: []string{},
},
}
for _, test := range tests {
test := test
i.Run(test.name, func() {
options := informer.ListOptions{
Filters: test.filters,
}
partitions := []partition.Partition{defaultPartition}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
cfgMaps, total, continueToken, err := cache.ListByOptions(ctx, options, partitions, testNamespace)
i.Require().NoError(err)
// since there's no additional pages, the continue token should be empty
i.Require().Equal("", continueToken)
i.Require().NotNil(cfgMaps)
// assert instead of require so that we can see the full evaluation of # of resources returned
i.Assert().Equal(len(test.wantNames), total)
i.Assert().Len(cfgMaps.Items, len(test.wantNames))
requireNames := sets.Set[string]{}
requireNames.Insert(test.wantNames...)
gotNames := sets.Set[string]{}
for _, configMap := range cfgMaps.Items {
gotNames.Insert(configMap.GetName())
}
i.Require().True(requireNames.Equal(gotNames), "wanted %v, got %v", requireNames, gotNames)
})
}
}
func (i *IntegrationSuite) createCacheAndFactory(fields [][]string, transformFunc cache.TransformFunc) (*factory.Cache, *factory.CacheFactory, error) {
cacheFactory, err := factory.NewCacheFactory()
if err != nil {
return nil, nil, fmt.Errorf("unable to make factory: %w", err)
}
dynamicClient, err := dynamic.NewForConfig(&i.restCfg)
if err != nil {
return nil, nil, fmt.Errorf("unable to make dynamicClient: %w", err)
}
configMapGVK := schema.GroupVersionKind{
Group: "",
Version: "v1",
Kind: "ConfigMap",
}
configMapGVR := schema.GroupVersionResource{
Group: "",
Version: "v1",
Resource: "configmaps",
}
dynamicResource := dynamicClient.Resource(configMapGVR).Namespace(testNamespace)
cache, err := cacheFactory.CacheFor(fields, transformFunc, dynamicResource, configMapGVK, true, true)
if err != nil {
return nil, nil, fmt.Errorf("unable to make cache: %w", err)
}
return &cache, cacheFactory, nil
}
func (i *IntegrationSuite) waitForCacheReady(readyResourceNames []string, namespace string, cache *factory.Cache) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
return wait.PollUntilContextCancel(ctx, time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) {
var options informer.ListOptions
partitions := []partition.Partition{defaultPartition}
cacheCtx, cacheCancel := context.WithTimeout(ctx, time.Second*5)
defer cacheCancel()
currentResources, total, _, err := cache.ListByOptions(cacheCtx, options, partitions, namespace)
if err != nil {
// note that we don't return the error since that would stop the polling
return false, nil
}
if total != len(readyResourceNames) {
return false, nil
}
wantNames := sets.Set[string]{}
wantNames.Insert(readyResourceNames...)
gotNames := sets.Set[string]{}
for _, current := range currentResources.Items {
name := current.GetName()
if !wantNames.Has(name) {
return true, fmt.Errorf("got resource %s which wasn't expected", name)
}
gotNames.Insert(name)
}
return wantNames.Equal(gotNames), nil
})
}
func TestIntegrationSuite(t *testing.T) {
suite.Run(t, new(IntegrationSuite))
}

View File

@ -0,0 +1,24 @@
/*
Package partition represents listing parameters. They can be used to specify which namespaces a caller would like included
in a response, or which specific objects they are looking for.
*/
package partition
import (
"k8s.io/apimachinery/pkg/util/sets"
)
// Partition represents filtering of a request's results
type Partition struct {
// if true, do not apply any filtering, return all results. Overrides all other fields
Passthrough bool
// if non-empty, only resources in the specified namespaces will be returned
Namespace string
// if true, return all results, while still honoring Namespace. Overrides Names
All bool
// if non-empty, only resources with matching names will be returned
Names sets.Set[string]
}

View File

@ -0,0 +1,204 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows
//
// Package store is a generated GoMock package.
package store
import (
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}
// MockRows is a mock of Rows interface.
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
}
// MockRowsMockRecorder is the mock recorder for MockRows.
type MockRowsMockRecorder struct {
mock *MockRows
}
// NewMockRows creates a new mock instance.
func NewMockRows(ctrl *gomock.Controller) *MockRows {
mock := &MockRows{ctrl: ctrl}
mock.recorder = &MockRowsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRows) EXPECT() *MockRowsMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRows) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRowsMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close))
}
// Err mocks base method.
func (m *MockRows) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err.
func (mr *MockRowsMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err))
}
// Next mocks base method.
func (m *MockRows) Next() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next")
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next.
func (mr *MockRowsMockRecorder) Next() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next))
}
// Scan mocks base method.
func (m *MockRows) Scan(arg0 ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Scan indicates an expected call of Scan.
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
}

351
pkg/sqlcache/store/store.go Normal file
View File

@ -0,0 +1,351 @@
/*
Package store contains the sql backed store. It persists objects to a sqlite database.
*/
package store
import (
"context"
"database/sql"
"fmt"
"reflect"
"github.com/rancher/steve/pkg/sqlcache/db"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
"k8s.io/client-go/tools/cache"
// needed for drivers
_ "modernc.org/sqlite"
)
const (
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
listKeysStmtFmt = `SELECT key FROM "%s"`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
key TEXT UNIQUE NOT NULL PRIMARY KEY,
object BLOB,
objectnonce BLOB,
dekid INTEGER
)`
)
// Store is a SQLite-backed cache.Store
type Store struct {
DBClient
name string
typ reflect.Type
keyFunc cache.KeyFunc
shouldEncrypt bool
upsertQuery string
deleteQuery string
getQuery string
listQuery string
listKeysQuery string
upsertStmt *sql.Stmt
deleteStmt *sql.Stmt
getStmt *sql.Stmt
listStmt *sql.Stmt
listKeysStmt *sql.Stmt
afterUpsert []func(key string, obj any, tx db.TXClient) error
afterDelete []func(key string, tx db.TXClient) error
}
// Test that Store implements cache.Indexer
var _ cache.Store = (*Store)(nil)
type DBClient interface {
BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error)
Prepare(stmt string) *sql.Stmt
QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error)
ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error)
ReadStrings(rows db.Rows) ([]string, error)
ReadInt(rows db.Rows) (int, error)
Upsert(tx db.TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error
CloseStmt(closable db.Closable) error
}
// NewStore creates a SQLite-backed cache.Store for objects of the given example type
func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool, name string) (*Store, error) {
s := &Store{
name: name,
typ: reflect.TypeOf(example),
DBClient: c,
keyFunc: keyFunc,
shouldEncrypt: shouldEncrypt,
afterUpsert: []func(key string, obj any, tx db.TXClient) error{},
afterDelete: []func(key string, tx db.TXClient) error{},
}
// once multiple informerfactories are needed, this can accept the case where table already exists error is received
txC, err := s.BeginTx(context.Background(), true)
if err != nil {
return nil, err
}
createTableQuery := fmt.Sprintf(createTableFmt, db.Sanitize(s.name))
err = txC.Exec(createTableQuery)
if err != nil {
return nil, &db.QueryError{QueryString: createTableQuery, Err: err}
}
err = txC.Commit()
if err != nil {
return nil, err
}
s.upsertQuery = fmt.Sprintf(upsertStmtFmt, db.Sanitize(s.name))
s.deleteQuery = fmt.Sprintf(deleteStmtFmt, db.Sanitize(s.name))
s.getQuery = fmt.Sprintf(getStmtFmt, db.Sanitize(s.name))
s.listQuery = fmt.Sprintf(listStmtFmt, db.Sanitize(s.name))
s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, db.Sanitize(s.name))
s.upsertStmt = s.Prepare(s.upsertQuery)
s.deleteStmt = s.Prepare(s.deleteQuery)
s.getStmt = s.Prepare(s.getQuery)
s.listStmt = s.Prepare(s.listQuery)
s.listKeysStmt = s.Prepare(s.listKeysQuery)
return s, nil
}
/* Core methods */
// upsert saves an obj with its key, or updates key with obj if it exists in this Store
func (s *Store) upsert(key string, obj any) error {
tx, err := s.BeginTx(context.Background(), true)
if err != nil {
return err
}
err = s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt)
if err != nil {
return &db.QueryError{QueryString: s.upsertQuery, Err: err}
}
err = s.runAfterUpsert(key, obj, tx)
if err != nil {
return err
}
return tx.Commit()
}
// deleteByKey deletes the object associated with key, if it exists in this Store
func (s *Store) deleteByKey(key string) error {
tx, err := s.BeginTx(context.Background(), true)
if err != nil {
return err
}
err = tx.StmtExec(tx.Stmt(s.deleteStmt), key)
if err != nil {
return &db.QueryError{QueryString: s.deleteQuery, Err: err}
}
err = s.runAfterDelete(key, tx)
if err != nil {
return err
}
return tx.Commit()
}
// GetByKey returns the object associated with the given object's key
func (s *Store) GetByKey(key string) (item any, exists bool, err error) {
rows, err := s.QueryForRows(context.TODO(), s.getStmt, key)
if err != nil {
return nil, false, &db.QueryError{QueryString: s.getQuery, Err: err}
}
result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt)
if err != nil {
return nil, false, err
}
if len(result) == 0 {
return nil, false, nil
}
return result[0], true, nil
}
/* Satisfy cache.Store */
// Add saves an obj, or updates it if it exists in this Store
func (s *Store) Add(obj any) error {
key, err := s.keyFunc(obj)
if err != nil {
return err
}
err = s.upsert(key, obj)
return err
}
// Update saves an obj, or updates it if it exists in this Store
func (s *Store) Update(obj any) error {
return s.Add(obj)
}
// Delete deletes the given object, if it exists in this Store
func (s *Store) Delete(obj any) error {
key, err := s.keyFunc(obj)
if err != nil {
return err
}
return s.deleteByKey(key)
}
// List returns a list of all the currently known objects
// Note: I/O errors will panic this function, as the interface signature does not allow returning errors
func (s *Store) List() []any {
rows, err := s.QueryForRows(context.TODO(), s.listStmt)
if err != nil {
panic(&db.QueryError{QueryString: s.listQuery, Err: err})
}
result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt)
if err != nil {
panic(fmt.Errorf("error in Store.List: %w", err))
}
return result
}
// ListKeys returns a list of all the keys currently in this Store
// Note: Atm it doesn't appear returning nil in the case of an error has any detrimental effects. An error is not
// uncommon enough nor does it appear to necessitate a panic.
func (s *Store) ListKeys() []string {
rows, err := s.QueryForRows(context.TODO(), s.listKeysStmt)
if err != nil {
fmt.Printf("Unexpected error in store.ListKeys: while executing query: %s got error: %v", s.listKeysQuery, err)
return []string{}
}
result, err := s.ReadStrings(rows)
if err != nil {
fmt.Printf("Unexpected error in store.ListKeys: %v\n", err)
return []string{}
}
return result
}
// Get returns the object with the same key as obj
func (s *Store) Get(obj any) (item any, exists bool, err error) {
key, err := s.keyFunc(obj)
if err != nil {
return nil, false, err
}
return s.GetByKey(key)
}
// Replace will delete the contents of the Store, using instead the given list
func (s *Store) Replace(objects []any, _ string) error {
objectMap := map[string]any{}
for _, object := range objects {
key, err := s.keyFunc(object)
if err != nil {
return err
}
objectMap[key] = object
}
return s.replaceByKey(objectMap)
}
// replaceByKey will delete the contents of the Store, using instead the given key to obj map
func (s *Store) replaceByKey(objects map[string]any) error {
txC, err := s.BeginTx(context.Background(), true)
if err != nil {
return err
}
txCListKeys := txC.Stmt(s.listKeysStmt)
rows, err := s.QueryForRows(context.TODO(), txCListKeys)
if err != nil {
return err
}
keys, err := s.ReadStrings(rows)
if err != nil {
return err
}
for _, key := range keys {
err = txC.StmtExec(txC.Stmt(s.deleteStmt), key)
if err != nil {
return err
}
err = s.runAfterDelete(key, txC)
if err != nil {
return err
}
}
for key, obj := range objects {
err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt)
if err != nil {
return err
}
err = s.runAfterUpsert(key, obj, txC)
if err != nil {
return err
}
}
return txC.Commit()
}
// Resync is a no-op and is deprecated
func (s *Store) Resync() error {
return nil
}
/* Utilities */
// RegisterAfterUpsert registers a func to be called after each upsert
func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC db.TXClient) error) {
s.afterUpsert = append(s.afterUpsert, f)
}
func (s *Store) GetName() string {
return s.name
}
func (s *Store) GetShouldEncrypt() bool {
return s.shouldEncrypt
}
func (s *Store) GetType() reflect.Type {
return s.typ
}
// keep
// runAfterUpsert executes functions registered to run after upsert
func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error {
for _, f := range s.afterUpsert {
err := f(key, obj, txC)
if err != nil {
return err
}
}
return nil
}
// RegisterAfterDelete registers a func to be called after each deletion
func (s *Store) RegisterAfterDelete(f func(key string, txC db.TXClient) error) {
s.afterDelete = append(s.afterDelete, f)
}
// keep
// runAfterDelete executes functions registered to run after upsert
func (s *Store) runAfterDelete(key string, txC db.TXClient) error {
for _, f := range s.afterDelete {
err := f(key, txC)
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,165 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient
//
// Package store is a generated GoMock package.
package store
import (
context "context"
sql "database/sql"
reflect "reflect"
db "github.com/rancher/steve/pkg/sqlcache/db"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockDBClient is a mock of DBClient interface.
type MockDBClient struct {
ctrl *gomock.Controller
recorder *MockDBClientMockRecorder
}
// MockDBClientMockRecorder is the mock recorder for MockDBClient.
type MockDBClientMockRecorder struct {
mock *MockDBClient
}
// NewMockDBClient creates a new mock instance.
func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient {
mock := &MockDBClient{ctrl: ctrl}
mock.recorder = &MockDBClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder {
return m.recorder
}
// BeginTx mocks base method.
func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(db.TXClient)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1)
}
// CloseStmt mocks base method.
func (m *MockDBClient) CloseStmt(arg0 db.Closable) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CloseStmt", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// CloseStmt indicates an expected call of CloseStmt.
func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0)
}
// Prepare mocks base method.
func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Prepare indicates an expected call of Prepare.
func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0)
}
// QueryForRows mocks base method.
func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryForRows", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryForRows indicates an expected call of QueryForRows.
func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...)
}
// ReadInt mocks base method.
func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadInt", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadInt indicates an expected call of ReadInt.
func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0)
}
// ReadObjects mocks base method.
func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2)
ret0, _ := ret[0].([]any)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadObjects indicates an expected call of ReadObjects.
func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2)
}
// ReadStrings mocks base method.
func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadStrings", arg0)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadStrings indicates an expected call of ReadStrings.
func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0)
}
// Upsert mocks base method.
func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
return ret0
}
// Upsert indicates an expected call of Upsert.
func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4)
}

View File

@ -0,0 +1,649 @@
/*
Copyright 2023 SUSE LLC
Adapted from client-go, Copyright 2014 The Kubernetes Authors.
*/
package store
// Mocks for this test are generated with the following command.
//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient
//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows
//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/rancher/steve/pkg/sqlcache/db"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
const ()
const TEST_DB_LOCATION = "./sqlstore.sqlite"
func testStoreKeyFunc(obj interface{}) (string, error) {
return obj.(testStoreObject).Id, nil
}
type testStoreObject struct {
Id string
Val string
}
func TestAdd(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
testObject := testStoreObject{Id: "something", Val: "a"}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Add with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
txC.EXPECT().Commit().Return(nil)
err := store.Add(testObject)
assert.Nil(t, err)
// dbclient beginerr
},
})
tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
txC.EXPECT().Commit().Return(nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
var count int
store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx db.TXClient) error {
count++
return nil
})
err := store.Add(testObject)
assert.Nil(t, err)
assert.Equal(t, count, 1)
},
})
tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error {
return fmt.Errorf("error")
})
err := store.Add(testObject)
assert.NotNil(t, err)
// dbclient beginerr
},
})
tests = append(tests, testCase{description: "Add with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed"))
store := SetupStore(t, c, shouldEncrypt)
err := store.Add(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Add with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed"))
err := store.Add(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Add with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed"))
err := store.Add(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Add with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
txC.EXPECT().Commit().Return(fmt.Errorf("failed"))
err := store.Add(testObject)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// Update updates the given object in the accumulator associated with the given object's key
func TestUpdate(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
testObject := testStoreObject{Id: "something", Val: "a"}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Update with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
txC.EXPECT().Commit().Return(nil)
err := store.Update(testObject)
assert.Nil(t, err)
// dbclient beginerr
},
})
tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
txC.EXPECT().Commit().Return(nil)
var count int
store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error {
count++
return nil
})
err := store.Update(testObject)
assert.Nil(t, err)
assert.Equal(t, count, 1)
},
})
tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error {
return fmt.Errorf("error")
})
err := store.Update(testObject)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Update with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed"))
store := SetupStore(t, c, shouldEncrypt)
err := store.Update(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Update with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed"))
err := store.Update(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Update with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed"))
err := store.Update(testObject)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Update with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil)
txC.EXPECT().Commit().Return(fmt.Errorf("failed"))
err := store.Update(testObject)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// Delete deletes the given object from the accumulator associated with the given object's key
func TestDelete(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
testObject := testStoreObject{Id: "something", Val: "a"}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Delete with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
// deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil)
txC.EXPECT().Commit().Return(nil)
err := store.Delete(testObject)
assert.Nil(t, err)
},
})
tests = append(tests, testCase{description: "Delete with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error"))
// deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt
err := store.Delete(testObject)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Delete with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error"))
// deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt
err := store.Delete(testObject)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Delete with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
// deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
// tx.EXPECT().
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil)
txC.EXPECT().Commit().Return(fmt.Errorf("error"))
err := store.Delete(testObject)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// List returns a list of all the currently non-empty accumulators
func TestList(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
testObject := testStoreObject{Id: "something", Val: "a"}
var tests []testCase
tests = append(tests, testCase{description: "List with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil)
items := store.List()
assert.Len(t, items, 0)
},
})
tests = append(tests, testCase{description: "List with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
fakeItemsToReturn := []any{"something1", 2, false}
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(fakeItemsToReturn, nil)
items := store.List()
assert.Equal(t, fakeItemsToReturn, items)
},
})
tests = append(tests, testCase{description: "List with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error"))
defer func() {
recover()
}()
_ = store.List()
assert.Fail(t, "Store list should panic when ReadObjects returns an error")
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// ListKeys returns a list of all the keys currently associated with non-empty accumulators
func TestListKeys(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
var tests []testCase
tests = append(tests, testCase{description: "ListKeys with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{"a", "b", "c"}, nil)
keys := store.ListKeys()
assert.Len(t, keys, 3)
},
})
tests = append(tests, testCase{description: "ListKeys with DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
keys := store.ListKeys()
assert.Len(t, keys, 0)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// Get returns the accumulator associated with the given object's key
func TestGet(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
tests = append(tests, testCase{description: "Get with no DB Client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil)
item, exists, err := store.Get(testObject)
assert.Nil(t, err)
assert.Equal(t, item, testObject)
assert.True(t, exists)
},
})
tests = append(tests, testCase{description: "Get with no DB Client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil)
item, exists, err := store.Get(testObject)
assert.Nil(t, err)
assert.Equal(t, item, nil)
assert.False(t, exists)
},
})
tests = append(tests, testCase{description: "Get with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error"))
_, _, err := store.Get(testObject)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// GetByKey returns the accumulator associated with the given key
func TestGetByKey(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil)
item, exists, err := store.GetByKey(testObject.Id)
assert.Nil(t, err)
assert.Equal(t, item, testObject)
assert.True(t, exists)
},
})
tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil)
item, exists, err := store.GetByKey(testObject.Id)
assert.Nil(t, err)
assert.Equal(t, nil, item)
assert.False(t, exists)
},
})
tests = append(tests, testCase{description: "GetByKey with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil)
c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error"))
_, _, err := store.GetByKey(testObject.Id)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// Replace will delete the contents of the store, using instead the
// given list. Store takes ownership of the list, you should not reference
// it after calling this function.
func TestReplace(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
tests = append(tests, testCase{description: "Replace with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id)
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
txC.EXPECT().Commit()
err := store.Replace([]any{testObject}, testObject.Id)
assert.Nil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) {
c, tx := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil)
tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{}, nil)
c.EXPECT().Upsert(tx, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
tx.EXPECT().Commit()
err := store.Replace([]any{testObject}, testObject.Id)
assert.Nil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with no DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) {
c, tx := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil)
tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) {
c, tx := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil)
tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt)
txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil)
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
// Resync is meaningless in the terms appearing here but has
// meaning in some implementations that have non-trivial
// additional behavior (e.g., DeltaFIFO).
func TestResync(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T, shouldEncrypt bool)
}
var tests []testCase
tests = append(tests, testCase{description: "Resync shouldn't call the client, panic, or do anything else", test: func(t *testing.T, shouldEncrypt bool) {
c, _ := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
err := store.Resync()
assert.Nil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t, false) })
t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) })
}
}
func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) {
dbC := NewMockDBClient(gomock.NewController(t)) // add functionality once store expectation are known
txC := NewMockTXClient(gomock.NewController(t))
// stmt := NewMockStmt(gomock.NewController())
txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil)
txC.EXPECT().Commit().Return(nil)
dbC.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil)
// use stmt mock here
dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
return dbC, txC
}
func SetupStore(t *testing.T, client *MockDBClient, shouldEncrypt bool) *Store {
store, err := NewStore(testStoreObject{}, testStoreKeyFunc, client, shouldEncrypt, "testStoreObject")
if err != nil {
t.Error(err)
}
return store
}

View File

@ -0,0 +1,99 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt
//
// Package store is a generated GoMock package.
package store
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}

View File

@ -10,8 +10,8 @@ import (
"github.com/rancher/apiserver/pkg/apierror"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/informer"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/wrangler/v3/pkg/schemas/validation"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

View File

@ -8,8 +8,8 @@ import (
"testing"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/informer"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

View File

@ -13,8 +13,8 @@ import (
context "context"
reflect "reflect"
informer "github.com/rancher/lasso/pkg/cache/sql/informer"
partition "github.com/rancher/lasso/pkg/cache/sql/partition"
informer "github.com/rancher/steve/pkg/sqlcache/informer"
partition "github.com/rancher/steve/pkg/sqlcache/partition"
gomock "go.uber.org/mock/gomock"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

View File

@ -13,7 +13,7 @@ import (
reflect "reflect"
types "github.com/rancher/apiserver/pkg/types"
partition "github.com/rancher/lasso/pkg/cache/sql/partition"
partition "github.com/rancher/steve/pkg/sqlcache/partition"
gomock "go.uber.org/mock/gomock"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
watch "k8s.io/apimachinery/pkg/watch"

View File

@ -5,9 +5,9 @@ import (
"sort"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/accesscontrol"
"github.com/rancher/steve/pkg/attributes"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/wrangler/v3/pkg/kv"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

View File

@ -6,7 +6,7 @@ import (
"go.uber.org/mock/gomock"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/steve/pkg/accesscontrol"
"github.com/rancher/wrangler/v3/pkg/schemas"
"github.com/stretchr/testify/assert"

View File

@ -7,14 +7,14 @@ import (
"context"
"github.com/rancher/apiserver/pkg/types"
lassopartition "github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/accesscontrol"
cachepartition "github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/steve/pkg/stores/partition"
)
// Partitioner is an interface for interacting with partitions.
type Partitioner interface {
All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]lassopartition.Partition, error)
All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]cachepartition.Partition, error)
Store() UnstructuredStore
}

View File

@ -15,7 +15,7 @@ import (
"go.uber.org/mock/gomock"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/steve/pkg/accesscontrol"
"github.com/rancher/steve/pkg/stores/sqlproxy"
"github.com/rancher/wrangler/v3/pkg/generic"

View File

@ -14,9 +14,9 @@ import (
reflect "reflect"
types "github.com/rancher/apiserver/pkg/types"
informer "github.com/rancher/lasso/pkg/cache/sql/informer"
factory "github.com/rancher/lasso/pkg/cache/sql/informer/factory"
partition "github.com/rancher/lasso/pkg/cache/sql/partition"
informer "github.com/rancher/steve/pkg/sqlcache/informer"
factory "github.com/rancher/steve/pkg/sqlcache/informer/factory"
partition "github.com/rancher/steve/pkg/sqlcache/partition"
summary "github.com/rancher/wrangler/v3/pkg/summary"
gomock "go.uber.org/mock/gomock"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"

View File

@ -31,9 +31,9 @@ import (
"github.com/rancher/apiserver/pkg/apierror"
"github.com/rancher/apiserver/pkg/types"
"github.com/rancher/lasso/pkg/cache/sql/informer"
"github.com/rancher/lasso/pkg/cache/sql/informer/factory"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/informer/factory"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/wrangler/v3/pkg/data"
"github.com/rancher/wrangler/v3/pkg/schemas"
"github.com/rancher/wrangler/v3/pkg/schemas/validation"
@ -333,7 +333,7 @@ func gvkKey(group, version, kind string) string {
return group + "_" + version + "_" + kind
}
// getFieldsFromSchema converts object field names from types.APISchema's format into lasso's
// getFieldsFromSchema converts object field names from types.APISchema's format into steve's
// cache.sql.informer's slice format (e.g. "metadata.resourceVersion" is ["metadata", "resourceVersion"])
func getFieldsFromSchema(schema *types.APISchema) [][]string {
var fields [][]string
@ -757,7 +757,7 @@ func (s *Store) ListByPartitions(apiOp *types.APIRequest, schema *types.APISchem
list, total, continueToken, err := inf.ListByOptions(apiOp.Context(), opts, partitions, apiOp.Namespace)
if err != nil {
if errors.Is(err, informer.InvalidColumnErr) {
if errors.Is(err, informer.ErrInvalidColumn) {
return nil, 0, "", apierror.NewAPIError(validation.InvalidBodyContent, err.Error())
}
return nil, 0, "", err

View File

@ -12,11 +12,11 @@ import (
"github.com/rancher/wrangler/v3/pkg/schemas/validation"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"github.com/rancher/lasso/pkg/cache/sql/informer"
"github.com/rancher/lasso/pkg/cache/sql/informer/factory"
"github.com/rancher/lasso/pkg/cache/sql/partition"
"github.com/rancher/steve/pkg/attributes"
"github.com/rancher/steve/pkg/resources/common"
"github.com/rancher/steve/pkg/sqlcache/informer"
"github.com/rancher/steve/pkg/sqlcache/informer/factory"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/steve/pkg/stores/sqlpartition/listprocessor"
"github.com/rancher/steve/pkg/stores/sqlproxy/tablelistconvert"
"go.uber.org/mock/gomock"
@ -42,7 +42,7 @@ import (
)
//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./proxy_mocks_test.go github.com/rancher/steve/pkg/stores/sqlproxy Cache,ClientGetter,CacheFactory,SchemaColumnSetter,RelationshipNotifier,TransformBuilder
//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister
//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister
//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface
var c *watch.FakeWatcher

View File

@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister)
// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister
// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister
//
// Package sqlproxy is a generated GoMock package.
@ -13,8 +13,8 @@ import (
context "context"
reflect "reflect"
informer "github.com/rancher/lasso/pkg/cache/sql/informer"
partition "github.com/rancher/lasso/pkg/cache/sql/partition"
informer "github.com/rancher/steve/pkg/sqlcache/informer"
partition "github.com/rancher/steve/pkg/sqlcache/partition"
gomock "go.uber.org/mock/gomock"
unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
)

View File

@ -1,3 +1,13 @@
#!/bin/bash
go test ./...
if ! command -v setup-envtest; then
echo "setup-envtest is required for tests, but was not installed"
echo "see the 'Running Tests' section of the readme for install instructions"
exit 127
fi
minor=$(go mod graph | grep ' k8s.io/client-go@' | head -n1 | cut -d@ -f2 | cut -d '.' -f 2)
version="1.$minor.x"
export KUBEBUILDER_ASSETS=$(setup-envtest use -p path "$version")
go test ./...