diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b34d3f18..6b3275bd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,6 +25,8 @@ jobs: uses: golangci/golangci-lint-action@a4f60bb28d35aeee14e6880718e0c85ff1882e64 # v6.0.1 with: version: v1.63.4 + - name: Install env-test + run: go install sigs.k8s.io/controller-runtime/tools/setup-envtest@latest - name: Build run: make build-bin - name: Test diff --git a/go.mod b/go.mod index dedfd67f..708264fb 100644 --- a/go.mod +++ b/go.mod @@ -82,6 +82,9 @@ require ( k8s.io/klog v1.0.0 k8s.io/kube-aggregator v0.30.1 k8s.io/kube-openapi v0.0.0-20240411171206-dc4e619f62f3 + k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 + modernc.org/sqlite v1.29.10 + sigs.k8s.io/controller-runtime v0.19.0 ) require ( @@ -95,6 +98,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch v5.9.0+incompatible // indirect + github.com/evanphx/json-patch/v5 v5.9.0 // indirect github.com/felixge/httpsnoop v1.0.3 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -150,12 +154,10 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/component-base v0.30.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect - k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.49.3 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect - modernc.org/sqlite v1.29.10 // indirect modernc.org/strutil v1.2.0 // indirect modernc.org/token v1.1.0 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0 // indirect diff --git a/go.sum b/go.sum index 96cb26df..92c9ae0a 100644 --- a/go.sum +++ b/go.sum @@ -683,12 +683,16 @@ github.com/envoyproxy/protoc-gen-validate v0.10.0/go.mod h1:DRjgyB0I43LtJapqN6Ni github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls= github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= +github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -714,6 +718,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= +github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/jsonreference v0.20.1/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= @@ -740,6 +746,7 @@ github.com/golang/glog v1.1.0/go.mod h1:pfYeQZ3JWZoXTV5sFc986z3HTpwQs9At6P4ImfuP github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -1096,6 +1103,10 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -1133,6 +1144,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -1546,6 +1559,8 @@ golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= +gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= @@ -1965,6 +1980,8 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0 h1:/U5vjBbQn3RCh sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.29.0/go.mod h1:z7+wmGM2dfIiLRfrC6jb5kV2Mq/sK1ZP303cxzkV5Y4= sigs.k8s.io/cli-utils v0.37.2 h1:GOfKw5RV2HDQZDJlru5KkfLO1tbxqMoyn1IYUxqBpNg= sigs.k8s.io/cli-utils v0.37.2/go.mod h1:V+IZZr4UoGj7gMJXklWBg6t5xbdThFBcpj4MrZuCYco= +sigs.k8s.io/controller-runtime v0.18.5 h1:nTHio/W+Q4aBlQMgbnC5hZb4IjIidyrizMai9P6n4Rk= +sigs.k8s.io/controller-runtime v0.18.5/go.mod h1:TVoGrfdpbA9VRFaRnKgk9P5/atA0pMwq+f+msb9M8Sg= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= sigs.k8s.io/structured-merge-diff/v4 v4.2.3/go.mod h1:qjx8mGObPmV2aSZepjQjbmb2ihdVs8cGKBraizNC69E= diff --git a/pkg/server/server.go b/pkg/server/server.go index 472b7d5b..9027ef11 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -67,7 +67,7 @@ type Options struct { AggregationSecretName string ClusterRegistry string ServerVersion string - // SQLCache enables the SQLite-based lasso caching mechanism + // SQLCache enables the SQLite-based caching mechanism SQLCache bool } diff --git a/pkg/sqlcache/Readme.md b/pkg/sqlcache/Readme.md new file mode 100644 index 00000000..99da82a2 --- /dev/null +++ b/pkg/sqlcache/Readme.md @@ -0,0 +1,157 @@ +# SQL Cache + +## Sections +- [ListOptions Informer](#listoptions-informer) + - [List Options](#list-options) + - [ListOption Indexer](#listoptions-indexer) + - [SQL Store](#sql-store) + - [Partitions](#partitions) +- [How to Use](#how-to-use) +- [Technical Information](#technical-information) + - [SQL Tables](#sql-tables) + - [SQLite Driver](#sqlite-driver) + - [Connection Pooling](#connection-pooling) + - [Encryption Defaults](#encryption-defaults) + - [Indexed Fields](#indexed-fields) + - [ListOptions Behavior](#listoptions-behavior) + - [Troubleshooting Sqlite](#troubleshooting-sqlite) + + + +## ListOptions Informer +The main usable feature from the SQL cache is the ListOptions Informer. The ListOptionsInformer provides listing functionality, +like any other informer, but with a wider array of options. The options are configured by informer.ListOptions. + +### List Options +ListOptions includes the following: +* Match filters for indexed fields. Filters are for specifying the value a given field in an object should be in order to +be included in the list. Filters can be set to equals or not equals. Filters can be set to look for partial matches or +exact (strict) matches. Filters can be OR'd and AND'd with one another. Filters only work on fields that have been indexed. +* Primary field and secondary field sorting order. Can choose up to two fields to sort on. Sort order can be ascending +or descending. Default sorting is to sort on metadata.namespace in ascending first and then sort on metadata.name. +* Page size to specify how many items to include in a response. +* Page number to specify offset. For example, a page size of 50 and a page number of 2, will return items starting at +index 50. Index will be dependent on sort. Page numbers start at 1. + +### ListOptions Factory +The ListOptions Factory helps manage multiple ListOption Informers. A user can call Factory.InformerFor(), to create new +ListOptions informers if they do not exist and retrieve existing ones. + +### ListOptions Indexer +Like all other informers, the ListOptions informer uses an indexer to cache objects of the informer's type. A few features +set the ListOptions Indexer apart from others indexers: +* an on-disk store instead of an in-memory store. +* accepts list options backed by SQL queries for extended search/filter/sorting capability. +* AES GCM encryption using key hierarchy. + +### SQL Store +The SQL store is the main interface for interacting with the database. This store backs the indexer, and provides all +functionality required by the cache.Store interface. + +### Partitions +Partitions are constraints for ListOptionsInform ListByOptions() method that are separate from ListOptions. Partitions +are strict conditions that dictate which namespaces or names can be searched from. These overrule ListOptions and are +intended to be used as a way of enforcing RBAC. + +## How to Use +```go + package main + import( + "k8s.io/client-go/dynamic" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + ) + + func main() { + cacheFactory, err := factory.NewCacheFactory() + if err != nil { + panic(err) + } + // config should be some rest config created from kubeconfig + // there are other ways to create a config and any client that conforms to k8s.io/client-go/dynamic.ResourceInterface + // will work. + client, err := dynamic.NewForConfig(config) + if err != nil { + panic(err) + } + + fields := [][]string{{"metadata", "name"}, {"metadata", "namespace"}} + opts := &informer.ListOptions{} + // gvk should be of type k8s.io/apimachinery/pkg/runtime/schema.GroupVersionKind + c, err := cacheFactory.CacheFor(fields, client, gvk) + if err != nil { + panic(err) + } + + // continueToken will just be an offset that can be used in Resume on a subsequent request to continue + // to next page + list, continueToken, err := c.ListByOptions(apiOp.Context(), opts, partitions, namespace) + if err != nil { + panic(err) + } + } +``` + +## Technical Information + +### SQL Tables +There are three tables that are created for the ListOption informer: +* object table - this contains objects, including all their fields, as blobs. These blobs may be encrypted. +* fields table - this contains specific fields of value for objects. These are specified on informer create and are fields +that it is desired to filter or order on. +* indices table - the indices table stores indexes created and objects' values for each index. This backs the generic indexer +that contains the functionality needed to conform to cache.Indexer. + +### SQLite Driver +There are multiple SQLite drivers that this package could have used. One of the most, if not the most, popular SQLite golang +drivers is [mattn/go-sqlite3](https://github.com/mattn/go-sqlite3). This driver is not being used because it requires enabling +the cgo option when compiling and at the moment steve's main consumer, rancher, does not compile with cgo. We did not want +the SQL informer to be the sole driver in switching to using cgo. Instead, modernc's driver which is in pure golang. Side-by-side +comparisons can be found indicating the cgo version is, as expected, more performant. If in the future it is deemed worthwhile +then the driver can be easily switched by replacing the empty import in `pkg/cache/sql/store` from `_ "modernc.org/sqlite"` to `_ "github.com/mattn/go-sqlite3"`. + +### Connection Pooling +While working with the `database/sql` package for go, it is important to understand how sql.Open() and other methods manage +connections. Open starts a connection pool; that is to say after calling open once, there may be anywhere from zero to many +connections attached to a sql.Connection. `database/sql` manages this connection pool under the hood. In most cases, an +application only need one sql.Connection, although sometimes application use two: one for writes, the other for reads. To +read more about the `sql` package's connection pooling read [Managing connections](https://go.dev/doc/database/manage-connections). + +The use of connection pooling and the fact that steve potentially has many go routines accessing the same connection pool, +means we have to be careful with writes. Exclusively using sql transaction to write helps ensure safety. To read more about +sql transactions read SQLite's [Transaction docs](https://www.sqlite.org/lang_transaction.html). + +### Encryption Defaults +By default only specified types are encrypted. These types are hard-coded and defined by defaultEncryptedResourceTypes +in `pkg/cache/sql/informer/factory/informer_factory.go`. To enabled encryption for all types, set the ENV variable +`CATTLE_ENCRYPT_CACHE_ALL` to "true". + +The key size used is 256 bits. Data-encryption-keys are stored in the object table and are rotated every 150,000 writes. + +### Indexed Fields +Filtering and sorting only work on indexed fields. These fields are defined when using `CacheFor`. Objects will +have the following indexes by default: +* Fields in informer.defaultIndexedFields +* Fields passed to InformerFor() + +### ListOptions Behavior +Defaults: +* Sort.PrimaryField: `metadata.namespace` +* Sort.SecondaryField: `metadata.name` +* Sort.PrimaryOrder: `ASC` (ascending) +* Sort.SecondaryOrder: `ASC` (ascending) +* All filters have partial matching set to false by default + +There are some uncommon ways someone could use ListOptions where it would be difficult to predict what the result would be. +Below is a non-exhaustive list of some of these cases and what the behavior is: +* Setting Pagination.Page but not Pagination.PageSize will cause Page to be ignored +* Setting Sort.SecondaryField only will sort as though it was Sort.PrimaryField. Sort.SecondaryOrder will still be applied +and Sort.PrimaryOrder will be ignored + +### Writing Secure Queries +Values should be supplied to SQL queries using placeholders, read [Avoiding SQL Injection Risk](https://go.dev/doc/database/sql-injection). Any other portions +of a query that may be user supplied, such as columns, should be carefully validated against a fixed set of acceptable values. + +### Troubleshooting SQLite +A useful tool for troubleshooting the database files is the sqlite command line tool. Another useful tool is the goland +sqlite plugin. Both of these tools can be used with the database files. diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go new file mode 100644 index 00000000..62d81c3f --- /dev/null +++ b/pkg/sqlcache/db/client.go @@ -0,0 +1,374 @@ +/* +Package db offers client struct and functions to interact with database connection. It provides encrypting, decrypting, +and a way to reset the database. +*/ +package db + +import ( + "bytes" + "context" + "database/sql" + "encoding/gob" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + + "github.com/pkg/errors" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + + // needed for drivers + _ "modernc.org/sqlite" +) + +const ( + // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve + InformerObjectCacheDBPath = "informer_object_cache.db" + + informerObjectCachePerms fs.FileMode = 0o600 +) + +// Client is a database client that provides encrypting, decrypting, and database resetting. +type Client struct { + conn Connection + connLock sync.RWMutex + encryptor Encryptor + decryptor Decryptor +} + +// Connection represents a connection pool. +type Connection interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + Exec(query string, args ...any) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + Close() error +} + +// Closable Closes an underlying connection and returns an error on failure. +type Closable interface { + Close() error +} + +// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them. +type Rows interface { + Next() bool + Err() error + Close() error + Scan(dest ...any) error +} + +// QueryError encapsulates an error while executing a query +type QueryError struct { + QueryString string + Err error +} + +// Error returns a string representation of this QueryError +func (e *QueryError) Error() string { + return "while executing query: " + e.QueryString + " got error: " + e.Err.Error() +} + +// Unwrap returns the underlying error +func (e *QueryError) Unwrap() error { + return e.Err +} + +// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed. +type TXClient interface { + StmtExec(stmt transaction.Stmt, args ...any) error + Exec(stmt string, args ...any) error + Commit() error + Stmt(stmt *sql.Stmt) transaction.Stmt + Cancel() error +} + +// Encryptor encrypts data with a key which is rotated to avoid wear-out. +type Encryptor interface { + // Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. + Encrypt([]byte) ([]byte, []byte, uint32, error) +} + +// Decryptor decrypts data previously encrypted by Encryptor. +type Decryptor interface { + // Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error. + Decrypt([]byte, []byte, uint32) ([]byte, error) +} + +// NewClient returns a Client. If the given connection is nil then a default one will be created. +func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) { + client := &Client{ + encryptor: encryptor, + decryptor: decryptor, + } + if c != nil { + client.conn = c + return client, nil + } + err := client.NewConnection() + if err != nil { + return nil, err + } + + return client, nil +} + +// Prepare prepares the given string into a sql statement on the client's connection. +func (c *Client) Prepare(stmt string) *sql.Stmt { + c.connLock.RLock() + defer c.connLock.RUnlock() + prepared, err := c.conn.Prepare(stmt) + if err != nil { + panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err)) + } + return prepared +} + +// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried +// given a sqlite busy error. +func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + return stmt.QueryContext(ctx, params...) +} + +// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant +// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection. +func (c *Client) CloseStmt(closable Closable) error { + return closable.Close() +} + +// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type, +// and returns a slice of those objects. +func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []any + for rows.Next() { + data, err := c.decryptScan(rows, shouldDecrypt) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + singleResult, err := fromBytes(data, typ) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + result = append(result, singleResult.Elem().Interface()) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadStrings scans the given rows into strings, and then returns the strings as a slice. +func (c *Client) ReadStrings(rows Rows) ([]string, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + var result []string + for rows.Next() { + var key string + err := rows.Scan(&key) + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + result = append(result, key) + } + err := rows.Err() + if err != nil { + return nil, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return nil, err + } + + return result, nil +} + +// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries) +func (c *Client) ReadInt(rows Rows) (int, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + + if !rows.Next() { + return 0, closeRowsOnError(rows, sql.ErrNoRows) + } + + var result int + err := rows.Scan(&result) + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Err() + if err != nil { + return 0, closeRowsOnError(rows, err) + } + + err = rows.Close() + if err != nil { + return 0, err + } + + return result, nil +} + +// BeginTx attempts to begin a transaction. +// If forWriting is true, this method blocks until all other concurrent forWriting +// transactions have either committed or rolled back. +// If forWriting is false, it is assumed the returned transaction will exclusively +// be used for DQL (eg. SELECT) queries. +// Not respecting the above rule might result in transactions failing with unexpected +// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked"). +// See discussion in https://github.com/rancher/lasso/pull/98 for details +func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) { + c.connLock.RLock() + defer c.connLock.RUnlock() + // note: this assumes _txlock=immediate in the connection string, see NewConnection + sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: !forWriting, + }) + if err != nil { + return nil, err + } + return transaction.NewClient(sqlTx), nil +} + +func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) { + var data, dataNonce sql.RawBytes + var kid uint32 + err := rows.Scan(&data, &dataNonce, &kid) + if err != nil { + return nil, err + } + if c.decryptor != nil && shouldDecrypt { + decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid) + if err != nil { + return nil, err + } + return decryptedData, nil + } + return data, nil +} + +// Upsert used to be called upsertEncrypted in store package before move +func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error { + objBytes := toBytes(obj) + var dataNonce []byte + var err error + var kid uint32 + if c.encryptor != nil && shouldEncrypt { + objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes) + if err != nil { + return err + } + } + + return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid) +} + +// toBytes encodes an object to a byte slice +func toBytes(obj any) []byte { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + err := enc.Encode(obj) + if err != nil { + panic(fmt.Errorf("error while gobbing object: %w", err)) + } + bb := buf.Bytes() + return bb +} + +// fromBytes decodes an object from a byte slice +func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) { + dec := gob.NewDecoder(bytes.NewReader(buf)) + singleResult := reflect.New(typ) + err := dec.DecodeValue(singleResult) + return singleResult, err +} + +// closeRowsOnError closes the sql.Rows object and wraps errors if needed +func closeRowsOnError(rows Rows, err error) error { + ce := rows.Close() + if ce != nil { + return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce) + } + + return err +} + +// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently +// creates new files. +func (c *Client) NewConnection() error { + c.connLock.Lock() + defer c.connLock.Unlock() + if c.conn != nil { + err := c.conn.Close() + if err != nil { + return err + } + } + err := os.RemoveAll(InformerObjectCacheDBPath) + if err != nil { + return err + } + + // Set the permissions in advance, because we can't control them if + // the file is created by a sql.Open call instead. + if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil { + return nil + } + + sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+ + // open SQLite file in read-write mode, creating it if it does not exist + "mode=rwc&"+ + // use the WAL journal mode for consistency and efficiency + "_pragma=journal_mode=wal&"+ + // do not even attempt to attain durability. Database is thrown away at pod restart + "_pragma=synchronous=off&"+ + // do check foreign keys and honor ON DELETE CASCADE + "_pragma=foreign_keys=on&"+ + // if two transactions want to write at the same time, allow 2 minutes for the first to complete + // before baling out + "_pragma=busy_timeout=120000&"+ + // default to IMMEDIATE mode for transactions. Setting this parameter is the only current way + // to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation + // of BeginTx + "_txlock=immediate") + if err != nil { + return err + } + + c.conn = sqlDB + return nil +} + +// This acts like "touch" for both existing files and non-existing files. +// permissions. +// +// It's created with the correct perms, and if the file already exists, it will +// be chmodded to the correct perms. +func touchFile(filename string, perms fs.FileMode) error { + f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms) + if err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + + return os.Chmod(filename, perms) +} diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go new file mode 100644 index 00000000..8b7951f1 --- /dev/null +++ b/pkg/sqlcache/db/client_test.go @@ -0,0 +1,667 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "io/fs" + "math" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +// Mocks for this test are generated with the following command. +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx + +type testStoreObject struct { + Id string + Val string +} + +func TestNewClient(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + expectedClient := &Client{ + conn: c, + encryptor: e, + decryptor: d, + } + client, err := NewClient(c, e, d) + assert.Nil(t, err) + assert.Equal(t, expectedClient, client) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryForRows(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + r := &sql.Rows{} + s.EXPECT().QueryContext(ctx).Return(r, nil) + rows, err := client.QueryForRows(ctx, s) + assert.Nil(t, err) + assert.Equal(t, r, rows) + }, + }) + tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + client := SetupClient(t, c, nil, nil) + s := NewMockStmt(gomock.NewController(t)) + ctx := context.TODO() + s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error")) + _, err := client.QueryForRows(ctx, s) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryObjects(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + var keyId uint32 = math.MaxUint32 + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes( + testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + *a[0].(*sql.RawBytes) = toBytes(testObject) + *a[1].(*sql.RawBytes) = toBytes(testObject) + *a[2].(*uint32) = keyId + }) + d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestQueryStrings(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 1, len(items)) + }, + }) + tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + for _, v := range a { + vk := v.(*string) + *vk = string(toBytes(testObject.Id)) + } + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadStrings(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + items, err := client.ReadStrings(r) + assert.Nil(t, err) + assert.Equal(t, 0, len(items)) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestReadInt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testResult := 42 + tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + p := a[0].(*int) + *p = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + result, err := client.ReadInt(r) + assert.Nil(t, err) + assert.Equal(t, 42, result) + }, + }) + tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(fmt.Errorf("error")) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(true) + r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) { + a[0] = testResult + }) + r.EXPECT().Err().Return(nil) + r.EXPECT().Close().Return(fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + r := SetupMockRows(t) + r.EXPECT().Next().Return(false) + r.EXPECT().Close().Return(nil) + client := SetupClient(t, c, e, d) + _, err := client.ReadInt(r) + assert.ErrorIs(t, err, sql.ErrNoRows) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestBegin(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), false) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + sqlTx := &sql.Tx{} + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil) + client := SetupClient(t, c, e, d) + txC, err := client.BeginTx(context.Background(), true) + assert.Nil(t, err) + assert.NotNil(t, txC) + }, + }) + tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error")) + client := SetupClient(t, c, e, d) + _, err := client.BeginTx(context.Background(), false) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + testObject := testStoreObject{Id: "something", Val: "a"} + var keyID uint32 = 5 + + // Tests with shouldEncryptSet to true + tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + testObjBytes := toBytes(testObject) + e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + testObjBytes := toBytes(testObject) + testByteValue := []byte("something") + e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error")) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) { + c := SetupMockConnection(t) + d := SetupMockDecryptor(t) + e := SetupMockEncryptor(t) + + client := SetupClient(t, c, e, d) + txC := NewMockTXClient(gomock.NewController(t)) + sqlStmt := &sql.Stmt{} + stmt := NewMockStmt(gomock.NewController(t)) + var testByteValue []byte + testObjBytes := toBytes(testObject) + txC.EXPECT().Stmt(sqlStmt).Return(stmt) + txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil) + err := client.Upsert(txC, sqlStmt, "somekey", testObject, false) + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestPrepare(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + sqlStmt := &sql.Stmt{} + c.EXPECT().Prepare("something").Return(sqlStmt, nil) + + stmt := client.Prepare("something") + assert.Equal(t, sqlStmt, stmt) + }, + }) + tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error")) + + assert.Panics(t, func() { client.Prepare("something") }) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestNewConnection(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) { + c := SetupMockConnection(t) + e := SetupMockEncryptor(t) + d := SetupMockDecryptor(t) + + client := SetupClient(t, c, e, d) + c.EXPECT().Close().Return(nil) + + err := client.NewConnection() + assert.Nil(t, err) + + // Create a transaction to ensure that the file is written to disk. + txC, err := client.BeginTx(context.Background(), false) + assert.NoError(t, err) + assert.NoError(t, txC.Commit()) + + assert.FileExists(t, InformerObjectCacheDBPath) + assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600) + + err = os.Remove(InformerObjectCacheDBPath) + if err != nil { + assert.Fail(t, "could not remove object cache path after test") + } + }, + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCommit(t *testing.T) { + +} + +func TestRollback(t *testing.T) { + +} + +func SetupMockConnection(t *testing.T) *MockConnection { + mockC := NewMockConnection(gomock.NewController(t)) + return mockC +} + +func SetupMockEncryptor(t *testing.T) *MockEncryptor { + mockE := NewMockEncryptor(gomock.NewController(t)) + return mockE +} + +func SetupMockDecryptor(t *testing.T) *MockDecryptor { + MockD := NewMockDecryptor(gomock.NewController(t)) + return MockD +} + +func SetupMockRows(t *testing.T) *MockRows { + MockR := NewMockRows(gomock.NewController(t)) + return MockR +} + +func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client { + c, _ := NewClient(connection, encryptor, decryptor) + return c +} + +func TestTouchFile(t *testing.T) { + t.Run("File doesn't exist before", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test1.txt") + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) + + t.Run("File exists with different permissions", func(t *testing.T) { + filename := filepath.Join(t.TempDir(), "test2.txt") + assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644)) + assert.NoError(t, touchFile(filename, 0600)) + assertFileHasPermissions(t, filename, 0600) + }) +} + +func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool { + t.Helper() + info, err := os.Lstat(fname) + if err != nil { + if os.IsNotExist(err) { + return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname)) + } + return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err)) + } + + // Stringifying the perms makes it easier to read than a Hex comparison. + assert.Equal(t, wantPerms.String(), info.Mode().Perm().String()) + + return true +} diff --git a/pkg/sqlcache/db/db_mocks_test.go b/pkg/sqlcache/db/db_mocks_test.go new file mode 100644 index 00000000..54199ba4 --- /dev/null +++ b/pkg/sqlcache/db/db_mocks_test.go @@ -0,0 +1,370 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} + +// MockConnection is a mock of Connection interface. +type MockConnection struct { + ctrl *gomock.Controller + recorder *MockConnectionMockRecorder +} + +// MockConnectionMockRecorder is the mock recorder for MockConnection. +type MockConnectionMockRecorder struct { + mock *MockConnection +} + +// NewMockConnection creates a new mock instance. +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(*sql.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1) +} + +// Close mocks base method. +func (m *MockConnection) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockConnectionMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close)) +} + +// Exec mocks base method. +func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...) +} + +// Prepare mocks base method. +func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0) +} + +// MockEncryptor is a mock of Encryptor interface. +type MockEncryptor struct { + ctrl *gomock.Controller + recorder *MockEncryptorMockRecorder +} + +// MockEncryptorMockRecorder is the mock recorder for MockEncryptor. +type MockEncryptorMockRecorder struct { + mock *MockEncryptor +} + +// NewMockEncryptor creates a new mock instance. +func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor { + mock := &MockEncryptor{ctrl: ctrl} + mock.recorder = &MockEncryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder { + return m.recorder +} + +// Encrypt mocks base method. +func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encrypt", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].([]byte) + ret2, _ := ret[2].(uint32) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Encrypt indicates an expected call of Encrypt. +func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0) +} + +// MockDecryptor is a mock of Decryptor interface. +type MockDecryptor struct { + ctrl *gomock.Controller + recorder *MockDecryptorMockRecorder +} + +// MockDecryptorMockRecorder is the mock recorder for MockDecryptor. +type MockDecryptorMockRecorder struct { + mock *MockDecryptor +} + +// NewMockDecryptor creates a new mock instance. +func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor { + mock := &MockDecryptor{ctrl: ctrl} + mock.recorder = &MockDecryptorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder { + return m.recorder +} + +// Decrypt mocks base method. +func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Decrypt indicates an expected call of Decrypt. +func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2) +} + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go new file mode 100644 index 00000000..136c6bd1 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -0,0 +1,90 @@ +/* +Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures. +The use of this package simplifies testing for callers by making the underlying transaction mock-able. +*/ +package transaction + +import ( + "context" + "database/sql" + + "github.com/pkg/errors" +) + +// Client provides a way to interact with the underlying sql transaction. +type Client struct { + sqlTx SQLTx +} + +// SQLTx represents a sql transaction +type SQLTx interface { + Exec(query string, args ...any) (sql.Result, error) + Stmt(stmt *sql.Stmt) *sql.Stmt + Commit() error + Rollback() error +} + +// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type +// because we are able to mock its outputs and do not need an actual connection. +type Stmt interface { + Exec(args ...any) (sql.Result, error) + Query(args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) +} + +// NewClient returns a Client with the given transaction assigned. +func NewClient(tx SQLTx) *Client { + return &Client{sqlTx: tx} +} + +// Commit commits the transaction and then unlocks the database. +func (c *Client) Commit() error { + return c.sqlTx.Commit() +} + +// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec() +// returns an error. +func (c *Client) Exec(stmt string, args ...any) error { + _, err := c.sqlTx.Exec(stmt, args...) + if err != nil { + return c.rollback(c.sqlTx, err) + } + return nil +} + +// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned +// here to aid in testing callers by providing a way to configure the statement's behavior. +func (c *Client) Stmt(stmt *sql.Stmt) Stmt { + s := c.sqlTx.Stmt(stmt) + return s +} + +// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The +// transaction is rolled back if Stmt.Exec() returns an error. +func (c *Client) StmtExec(stmt Stmt, args ...any) error { + _, err := stmt.Exec(args...) + if err != nil { + return c.rollback(c.sqlTx, err) + } + return nil +} + +// rollback handles rollbacks and wraps errors if needed +func (c *Client) rollback(tx SQLTx, err error) error { + rerr := tx.Rollback() + if rerr != nil { + return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr) + } + return errors.Wrapf(err, "Encountered error, successfully rolled back") +} + +// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned +// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too +// late. +func (c *Client) Cancel() error { + rerr := c.sqlTx.Rollback() + if rerr != sql.ErrTxDone { + return rerr + } + return nil +} diff --git a/pkg/sqlcache/db/transaction/transaction_mocks_test.go b/pkg/sqlcache/db/transaction/transaction_mocks_test.go new file mode 100644 index 00000000..0d7fdaa7 --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction_mocks_test.go @@ -0,0 +1,184 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// + +// Package transaction is a generated GoMock package. +package transaction + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. +func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/db/transaction/transaction_test.go b/pkg/sqlcache/db/transaction/transaction_test.go new file mode 100644 index 00000000..0ede5d2e --- /dev/null +++ b/pkg/sqlcache/db/transaction/transaction_test.go @@ -0,0 +1,182 @@ +package transaction + +import ( + "database/sql" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx + +func TestNewClient(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + c := NewClient(tx) + assert.Equal(t, tx, c.sqlTx) +} + +func TestCommit(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + tx.EXPECT().Commit().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.Commit() + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + tx.EXPECT().Commit().Return(fmt.Errorf("error")) + c := &Client{ + sqlTx: tx, + } + err := c.Commit() + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestExec(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmtStr := "some statement %s" + arg := 5 + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(fmt.Errorf("error")) + c := &Client{ + sqlTx: tx, + } + err := c.Exec(stmtStr, arg) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestStmt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := &sql.Stmt{} + var returnedTXStmt *sql.Stmt + // should be passed same statement and arg that was passed to parent function + tx.EXPECT().Stmt(stmt).Return(returnedTXStmt) + c := &Client{ + sqlTx: tx, + } + returnedStmt := c.Stmt(stmt) + // whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it + // won't be equal to an empty struct + assert.Equal(t, returnedTXStmt, returnedStmt) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestStmtExec(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, nil) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(nil) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) { + tx := NewMockSQLTx(gomock.NewController(t)) + stmt := NewMockStmt(gomock.NewController(t)) + arg := "something" + // should be passed same arg that was passed to parent function + stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) + tx.EXPECT().Rollback().Return(fmt.Errorf("error2")) + c := &Client{ + sqlTx: tx, + } + err := c.StmtExec(stmt, arg) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go new file mode 100644 index 00000000..1cac5caf --- /dev/null +++ b/pkg/sqlcache/db/transaction_mocks_test.go @@ -0,0 +1,184 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// + +// Package db is a generated GoMock package. +package db + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} + +// MockSQLTx is a mock of SQLTx interface. +type MockSQLTx struct { + ctrl *gomock.Controller + recorder *MockSQLTxMockRecorder +} + +// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. +type MockSQLTxMockRecorder struct { + mock *MockSQLTx +} + +// NewMockSQLTx creates a new mock instance. +func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { + mock := &MockSQLTx{ctrl: ctrl} + mock.recorder = &MockSQLTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockSQLTx) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) +} + +// Rollback mocks base method. +func (m *MockSQLTx) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) +} + +// Stmt mocks base method. +func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/db/utility.go b/pkg/sqlcache/db/utility.go new file mode 100644 index 00000000..a8f84d29 --- /dev/null +++ b/pkg/sqlcache/db/utility.go @@ -0,0 +1,8 @@ +package db + +import "strings" + +// Sanitize returns a string that can be used in SQL as a name +func Sanitize(s string) string { + return strings.ReplaceAll(s, "\"", "") +} diff --git a/pkg/sqlcache/encryption/encrypt.go b/pkg/sqlcache/encryption/encrypt.go new file mode 100644 index 00000000..a7783ac9 --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt.go @@ -0,0 +1,168 @@ +/* +Package encryption provides encryption and decryption functions, while +abstracting away key management concerns. +Uses AES-GCM encryption, with key rotation, keeping keys in memory. +*/ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "sync" + + "github.com/pkg/errors" +) + +var ( + ErrKeyNotFound = errors.New("data key not found") + // maxWriteCount holds the maximum amount of times the active key can be + // used, prior to it being rotated. 2^32 is the currently recommended key + // wear-out params by NIST for AES-GCM using random nonces. + maxWriteCount int64 = 1 << 32 +) + +const ( + keySize = 32 // 32 for AES-256 +) + +// Manager uses AES-GCM encryption and keeps in memory the data encryption +// keys. The active encryption key is automatically rotated once it has been +// used over a certain amount of times - defined by maxWriteCount. +type Manager struct { + dataKeys [][]byte + activeKeyCounter int64 + + // lock works as the mutual exclusion lock for dataKeys. + lock sync.RWMutex + // counterLock works as the mutual exclusion lock for activeKeyCounter. + counterLock sync.Mutex +} + +// NewManager returns Manager, which satisfies db.Encryptor and db.Decryptor +func NewManager() (*Manager, error) { + m := &Manager{ + dataKeys: [][]byte{}, + } + m.newDataEncryptionKey() + + return m, nil +} + +// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. +func (m *Manager) Encrypt(data []byte) ([]byte, []byte, uint32, error) { + dek, keyID, err := m.fetchActiveDataKey() + if err != nil { + return nil, nil, 0, err + } + aead, err := createGCMCypher(dek) + if err != nil { + return nil, nil, 0, err + } + edata, nonce, err := encrypt(aead, data) + if err != nil { + return nil, nil, 0, err + } + return edata, nonce, keyID, nil +} + +// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error. +func (m *Manager) Decrypt(edata, nonce []byte, keyID uint32) ([]byte, error) { + dek, err := m.key(keyID) + if err != nil { + return nil, err + } + + aead, err := createGCMCypher(dek) + if err != nil { + return nil, errors.Wrap(err, "failed to create GCMCypher from DEK") + } + data, err := aead.Open(nil, nonce, edata, nil) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to decrypt data using keyid %d", keyID)) + } + return data, nil +} + +func encrypt(aead cipher.AEAD, data []byte) ([]byte, []byte, error) { + if aead == nil { + return nil, nil, fmt.Errorf("aead is nil, cannot encrypt data") + } + nonce := make([]byte, aead.NonceSize()) + _, err := rand.Read(nonce) + if err != nil { + return nil, nil, err + } + sealed := aead.Seal(nil, nonce, data, nil) + return sealed, nonce, nil +} + +func createGCMCypher(key []byte) (cipher.AEAD, error) { + b, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aead, err := cipher.NewGCM(b) + if err != nil { + return nil, err + } + return aead, nil +} + +// fetchActiveDataKey returns the current data key and its key ID. +// Each call results in activeKeyCounter being incremented by 1. When the +// the activeKeyCounter exceeds maxWriteCount, the active data key is +// rotated - before being returned. +func (m *Manager) fetchActiveDataKey() ([]byte, uint32, error) { + m.counterLock.Lock() + defer m.counterLock.Unlock() + + m.activeKeyCounter++ + if m.activeKeyCounter >= maxWriteCount { + return m.newDataEncryptionKey() + } + + return m.activeKey() +} + +func (m *Manager) newDataEncryptionKey() ([]byte, uint32, error) { + dek := make([]byte, keySize) + _, err := rand.Read(dek) + if err != nil { + return nil, 0, err + } + + m.lock.Lock() + defer m.lock.Unlock() + + m.activeKeyCounter = 1 + + m.dataKeys = append(m.dataKeys, dek) + keyID := uint32(len(m.dataKeys) - 1) + + return dek, keyID, nil +} + +func (m *Manager) activeKey() ([]byte, uint32, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + nk := len(m.dataKeys) + if nk == 0 { + return nil, 0, ErrKeyNotFound + } + keyID := uint32(nk - 1) + + return m.dataKeys[keyID], keyID, nil +} + +func (m *Manager) key(keyID uint32) ([]byte, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.dataKeys) <= int(keyID) { + return nil, fmt.Errorf("%w: %v", ErrKeyNotFound, keyID) + } + return m.dataKeys[keyID], nil +} diff --git a/pkg/sqlcache/encryption/encrypt_test.go b/pkg/sqlcache/encryption/encrypt_test.go new file mode 100644 index 00000000..46f3300a --- /dev/null +++ b/pkg/sqlcache/encryption/encrypt_test.go @@ -0,0 +1,327 @@ +package encryption + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewManager(t *testing.T) { + m, err := NewManager() + if err != nil { + t.FailNow() + } + assert.NotNil(t, m) +} + +func TestEncrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test encrypt with arbitrary initial key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt without arbitrary initial key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipherText, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + dek := m.dataKeys[keyID] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + decryptedData, err := aead.Open(nil, nonce, cipherText, nil) + + require.Nil(t, err) + assert.Equal(t, testData, decryptedData) + }}) + tests = append(tests, testCase{description: "test encrypt: same data yield different cipher/nonce pair", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + + assert.Equal(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + tests = append(tests, testCase{description: "test encrypt with key rotation", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + cipher1, nonce1, keyID1, err := m.Encrypt(testData) + require.Nil(t, err) + assert.Len(t, cipher1, 25) + assert.Len(t, nonce1, 12) + assert.NotEmpty(t, cipher1) + assert.NotEmpty(t, nonce1) + + m.activeKeyCounter += maxWriteCount + + cipher2, nonce2, keyID2, err := m.Encrypt(testData) + require.Nil(t, err) + + assert.Equal(t, int64(1), m.activeKeyCounter) + assert.NotEqual(t, keyID1, keyID2) + assert.NotEqual(t, cipher1, cipher2, "each encrypt op must return a unique cipher") + assert.NotEqual(t, nonce1, nonce2, "each encrypt op must return a unique nonce") + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestDecrypt(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + var tests []testCase + + tests = append(tests, testCase{description: "test decrypt with arbitrary key", test: func(t *testing.T) { + testDEK := []byte{83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181, 83, 125, 203, 18, 75, 156, 24, 192, 119, 73, 157, 222, 143, 140, 231, 181} + + m, err := NewManager() + require.Nil(t, err) + + m.dataKeys[0] = testDEK + + testData := []byte("something") + + // encrypt data out of band. + b, err := aes.NewCipher(testDEK) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt without arbitrary key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipherText, nonce, 0) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + tests = append(tests, testCase{description: "test decrypt with wrong data nonce should return error", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + // generate random nonce. + randomNonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + // decrypted encrypted data using encrypted dek + _, err = m.Decrypt(cipherText, randomNonce, 0) + assert.NotNil(t, err) + }, + }) + + tests = append(tests, testCase{description: "test decrypt with DEK/nonce pair not used to encrypt should return error", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + // encrypt data out of band. + dek := m.dataKeys[0] + b, err := aes.NewCipher(dek) + require.Nil(t, err) + + aead, err := cipher.NewGCM(b) + require.Nil(t, err) + + nonce := make([]byte, aead.NonceSize()) + _, err = rand.Read(nonce) + require.Nil(t, err) + + cipherText := aead.Seal(nil, nonce, testData, nil) + + key, id, err := m.newDataEncryptionKey() + require.Nil(t, err) + m.dataKeys[id] = key + + plainText, err := m.Decrypt(cipherText, nonce, id) + assert.NotNil(t, err) + assert.Nil(t, plainText) + }, + }) + tests = append(tests, testCase{description: "test decrypt for non active key", test: func(t *testing.T) { + m, err := NewManager() + require.Nil(t, err) + + testData := []byte("something") + + cipher, nonce, keyID, err := m.Encrypt(testData) + require.Nil(t, err) + + // force key rotation. + m.activeKeyCounter += maxWriteCount + _, _, newKeyID, err := m.Encrypt(nil) + require.Nil(t, err) + require.NotEqual(t, keyID, newKeyID) + + // use manager to decrypt the data. + decryptedData, err := m.Decrypt(cipher, nonce, keyID) + require.Nil(t, err) + + assert.Equal(t, testData, decryptedData) + }, + }) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +var buf = make([]byte, 8192) + +func BenchmarkEncryption(b *testing.B) { + benchEncrypt(b, 1024) + benchEncrypt(b, 4096) + benchEncrypt(b, 8192) +} + +func BenchmarkDecryption(b *testing.B) { + benchDecrypt(b, 1024) + benchDecrypt(b, 4096) + benchDecrypt(b, 8192) +} + +func benchEncrypt(b *testing.B, size int) { + m, err := NewManager() + if err != nil { + b.Fatal("failed to create manager", err) + } + // disable auto rotation to avoid skewing results. + maxWriteCount = math.MaxInt32 + + b.Run(fmt.Sprintf("encrypt-%d", size), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + _, _, _, err := m.Encrypt(buf[:size]) + if err != nil { + b.Fatal("error encrypting data", err) + } + } + }) +} + +func benchDecrypt(b *testing.B, size int) { + m, err := NewManager() + if err != nil { + b.Fatal("failed to create manager", err) + } + + edata, enonce, kid, err := m.Encrypt(buf[:size]) + if err != nil { + b.Fatal("failed to encrypt data", err) + } + + b.Run(fmt.Sprintf("decrypt-%d", size), func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + _, err := m.Decrypt(edata, enonce, kid) + if err != nil { + b.Fatal("error encrypting data", err) + } + } + }) +} diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go new file mode 100644 index 00000000..7d2c81ce --- /dev/null +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +// + +// Package informer is a generated GoMock package. +package informer + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/informer/dynamic_mocks_test.go b/pkg/sqlcache/informer/dynamic_mocks_test.go new file mode 100644 index 00000000..07e169c1 --- /dev/null +++ b/pkg/sqlcache/informer/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. +func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) +} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go new file mode 100644 index 00000000..9ac55bb3 --- /dev/null +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -0,0 +1,121 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} diff --git a/pkg/sqlcache/informer/factory/dynamic_mocks_test.go b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go new file mode 100644 index 00000000..29e2c0fd --- /dev/null +++ b/pkg/sqlcache/informer/factory/dynamic_mocks_test.go @@ -0,0 +1,237 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/dynamic (interfaces: ResourceInterface) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + types "k8s.io/apimachinery/pkg/types" + watch "k8s.io/apimachinery/pkg/watch" +) + +// MockResourceInterface is a mock of ResourceInterface interface. +type MockResourceInterface struct { + ctrl *gomock.Controller + recorder *MockResourceInterfaceMockRecorder +} + +// MockResourceInterfaceMockRecorder is the mock recorder for MockResourceInterface. +type MockResourceInterfaceMockRecorder struct { + mock *MockResourceInterface +} + +// NewMockResourceInterface creates a new mock instance. +func NewMockResourceInterface(ctrl *gomock.Controller) *MockResourceInterface { + mock := &MockResourceInterface{ctrl: ctrl} + mock.recorder = &MockResourceInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceInterface) EXPECT() *MockResourceInterfaceMockRecorder { + return m.recorder +} + +// Apply mocks base method. +func (m *MockResourceInterface) Apply(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions, arg4 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Apply", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Apply indicates an expected call of Apply. +func (mr *MockResourceInterfaceMockRecorder) Apply(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Apply", reflect.TypeOf((*MockResourceInterface)(nil).Apply), varargs...) +} + +// ApplyStatus mocks base method. +func (m *MockResourceInterface) ApplyStatus(arg0 context.Context, arg1 string, arg2 *unstructured.Unstructured, arg3 v1.ApplyOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ApplyStatus", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ApplyStatus indicates an expected call of ApplyStatus. +func (mr *MockResourceInterfaceMockRecorder) ApplyStatus(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApplyStatus", reflect.TypeOf((*MockResourceInterface)(nil).ApplyStatus), arg0, arg1, arg2, arg3) +} + +// Create mocks base method. +func (m *MockResourceInterface) Create(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.CreateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Create", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockResourceInterfaceMockRecorder) Create(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockResourceInterface)(nil).Create), varargs...) +} + +// Delete mocks base method. +func (m *MockResourceInterface) Delete(arg0 context.Context, arg1 string, arg2 v1.DeleteOptions, arg3 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Delete", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockResourceInterfaceMockRecorder) Delete(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockResourceInterface)(nil).Delete), varargs...) +} + +// DeleteCollection mocks base method. +func (m *MockResourceInterface) DeleteCollection(arg0 context.Context, arg1 v1.DeleteOptions, arg2 v1.ListOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCollection", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCollection indicates an expected call of DeleteCollection. +func (mr *MockResourceInterfaceMockRecorder) DeleteCollection(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCollection", reflect.TypeOf((*MockResourceInterface)(nil).DeleteCollection), arg0, arg1, arg2) +} + +// Get mocks base method. +func (m *MockResourceInterface) Get(arg0 context.Context, arg1 string, arg2 v1.GetOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockResourceInterfaceMockRecorder) Get(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockResourceInterface)(nil).Get), varargs...) +} + +// List mocks base method. +func (m *MockResourceInterface) List(arg0 context.Context, arg1 v1.ListOptions) (*unstructured.UnstructuredList, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", arg0, arg1) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockResourceInterfaceMockRecorder) List(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockResourceInterface)(nil).List), arg0, arg1) +} + +// Patch mocks base method. +func (m *MockResourceInterface) Patch(arg0 context.Context, arg1 string, arg2 types.PatchType, arg3 []byte, arg4 v1.PatchOptions, arg5 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3, arg4} + for _, a := range arg5 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Patch", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Patch indicates an expected call of Patch. +func (mr *MockResourceInterfaceMockRecorder) Patch(arg0, arg1, arg2, arg3, arg4 any, arg5 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3, arg4}, arg5...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockResourceInterface)(nil).Patch), varargs...) +} + +// Update mocks base method. +func (m *MockResourceInterface) Update(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions, arg3 ...string) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2} + for _, a := range arg3 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Update", varargs...) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockResourceInterfaceMockRecorder) Update(arg0, arg1, arg2 any, arg3 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2}, arg3...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockResourceInterface)(nil).Update), varargs...) +} + +// UpdateStatus mocks base method. +func (m *MockResourceInterface) UpdateStatus(arg0 context.Context, arg1 *unstructured.Unstructured, arg2 v1.UpdateOptions) (*unstructured.Unstructured, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateStatus", arg0, arg1, arg2) + ret0, _ := ret[0].(*unstructured.Unstructured) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateStatus indicates an expected call of UpdateStatus. +func (mr *MockResourceInterfaceMockRecorder) UpdateStatus(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStatus", reflect.TypeOf((*MockResourceInterface)(nil).UpdateStatus), arg0, arg1, arg2) +} + +// Watch mocks base method. +func (m *MockResourceInterface) Watch(arg0 context.Context, arg1 v1.ListOptions) (watch.Interface, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Watch", arg0, arg1) + ret0, _ := ret[0].(watch.Interface) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockResourceInterfaceMockRecorder) Watch(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockResourceInterface)(nil).Watch), arg0, arg1) +} diff --git a/pkg/sqlcache/informer/factory/factory_mocks_test.go b/pkg/sqlcache/informer/factory/factory_mocks_test.go new file mode 100644 index 00000000..a7adab6a --- /dev/null +++ b/pkg/sqlcache/informer/factory/factory_mocks_test.go @@ -0,0 +1,179 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer/factory (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient +// + +// Package factory is a generated GoMock package. +package factory + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// NewConnection mocks base method. +func (m *MockDBClient) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockDBClientMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockDBClient)(nil).NewConnection)) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go new file mode 100644 index 00000000..ec2da5cc --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -0,0 +1,186 @@ +/* +Package factory provides a cache factory for the sql-based cache. +*/ +package factory + +import ( + "fmt" + "os" + "sync" + "time" + + "github.com/rancher/lasso/pkg/log" + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/encryption" + "github.com/rancher/steve/pkg/sqlcache/informer" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// EncryptAllEnvVar is set to "true" if users want all types' data blobs to be encrypted in SQLite +// otherwise only variables in defaultEncryptedResourceTypes will have their blobs encrypted +const EncryptAllEnvVar = "CATTLE_ENCRYPT_CACHE_ALL" + +// CacheFactory builds Informer instances and keeps a cache of instances it created +type CacheFactory struct { + wg wait.Group + dbClient DBClient + stopCh chan struct{} + mutex sync.RWMutex + encryptAll bool + + newInformer newInformer + + informers map[schema.GroupVersionKind]*guardedInformer + informersMutex sync.Mutex +} + +type guardedInformer struct { + informer *informer.Informer + mutex *sync.Mutex +} + +type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespace bool) (*informer.Informer, error) + +type DBClient interface { + informer.DBClient + sqlStore.DBClient + connector +} + +type Cache struct { + informer.ByOptionsLister +} + +type connector interface { + NewConnection() error +} + +var defaultEncryptedResourceTypes = map[schema.GroupVersionKind]struct{}{ + { + Version: "v1", + Kind: "Secret", + }: {}, +} + +// NewCacheFactory returns an informer factory instance +func NewCacheFactory() (*CacheFactory, error) { + m, err := encryption.NewManager() + if err != nil { + return nil, err + } + dbClient, err := db.NewClient(nil, m, m) + if err != nil { + return nil, err + } + return &CacheFactory{ + wg: wait.Group{}, + stopCh: make(chan struct{}), + encryptAll: os.Getenv(EncryptAllEnvVar) == "true", + dbClient: dbClient, + newInformer: informer.NewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + }, nil +} + +// CacheFor returns an informer for given GVK, using sql store indexed with fields, using the specified client. For virtual fields, they must be added by the transform function +// and specified by fields to be used for later fields. +func (f *CacheFactory) CacheFor(fields [][]string, transform cache.TransformFunc, client dynamic.ResourceInterface, gvk schema.GroupVersionKind, namespaced bool, watchable bool) (Cache, error) { + // First of all block Reset() until we are done + f.mutex.RLock() + defer f.mutex.RUnlock() + + // Second, check if the informer and its accompanying informer-specific mutex exist already in the informers cache + // If not, start by creating such informer-specific mutex. That is used later to ensure no two goroutines create + // informers for the same GVK at the same type + f.informersMutex.Lock() + // Note: the informers cache is protected by informersMutex, which we don't want to hold for very long because + // that blocks CacheFor for other GVKs, hence not deferring unlock here + gi, ok := f.informers[gvk] + if !ok { + gi = &guardedInformer{ + informer: nil, + mutex: &sync.Mutex{}, + } + f.informers[gvk] = gi + } + f.informersMutex.Unlock() + + // At this point an informer-specific mutex (gi.mutex) is guaranteed to exist. Lock it + gi.mutex.Lock() + defer gi.mutex.Unlock() + + // Then: if the informer really was not created yet (first time here or previous times have errored out) + // actually create the informer + if gi.informer == nil { + start := time.Now() + log.Debugf("CacheFor STARTS creating informer for %v", gvk) + defer func() { + log.Debugf("CacheFor IS DONE creating informer for %v (took %v)", gvk, time.Now().Sub(start)) + }() + + _, encryptResourceAlways := defaultEncryptedResourceTypes[gvk] + shouldEncrypt := f.encryptAll || encryptResourceAlways + i, err := f.newInformer(client, fields, transform, gvk, f.dbClient, shouldEncrypt, namespaced) + if err != nil { + return Cache{}, err + } + + err = i.SetWatchErrorHandler(func(r *cache.Reflector, err error) { + if !watchable && errors.IsMethodNotSupported(err) { + // expected, continue without logging + return + } + cache.DefaultWatchErrorHandler(r, err) + }) + if err != nil { + return Cache{}, err + } + + f.wg.StartWithChannel(f.stopCh, i.Run) + + gi.informer = i + } + + if !cache.WaitForCacheSync(f.stopCh, gi.informer.HasSynced) { + return Cache{}, fmt.Errorf("failed to sync SQLite Informer cache for GVK %v", gvk) + } + + // At this point the informer is ready, return it + return Cache{ByOptionsLister: gi.informer}, nil +} + +// Reset closes the stopCh which stops any running informers, assigns a new stopCh, resets the GVK-informer cache, and resets +// the database connection which wipes any current sqlite database at the default location. +func (f *CacheFactory) Reset() error { + if f.dbClient == nil { + // nothing to reset + return nil + } + + // first of all wait until all CacheFor() calls that create new informers are finished. Also block any new ones + f.mutex.Lock() + defer f.mutex.Unlock() + + // now that we are alone, stop all informers created until this point + close(f.stopCh) + f.stopCh = make(chan struct{}) + f.wg.Wait() + + // and get rid of all references to those informers and their mutexes + f.informersMutex.Lock() + defer f.informersMutex.Unlock() + f.informers = make(map[schema.GroupVersionKind]*guardedInformer) + + // finally, reset the DB connection + err := f.dbClient.NewConnection() + if err != nil { + return err + } + + return nil +} diff --git a/pkg/sqlcache/informer/factory/informer_factory_test.go b/pkg/sqlcache/informer/factory/informer_factory_test.go new file mode 100644 index 00000000..e3b96562 --- /dev/null +++ b/pkg/sqlcache/informer/factory/informer_factory_test.go @@ -0,0 +1,287 @@ +package factory + +import ( + "os" + "testing" + "time" + + "github.com/rancher/steve/pkg/sqlcache/informer" + + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer + +func TestNewCacheFactory(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned, should return no errors", test: func(t *testing.T) { + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.False(t, f.encryptAll) + }}) + tests = append(tests, testCase{description: "NewCacheFactory() with no errors returned and EncryptAllEnvVar set to true, should return no errors and have encryptAll set to true", test: func(t *testing.T) { + err := os.Setenv(EncryptAllEnvVar, "true") + assert.Nil(t, err) + f, err := NewCacheFactory() + assert.Nil(t, err) + assert.Nil(t, err) + assert.NotNil(t, f.dbClient) + assert.True(t, f.encryptAll) + }}) + // cannot run as parallel because tests involve changing env var + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestCacheFor(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh not closed, should return no error and should call Informer.Run(). A subsequent call to CacheFor() should return same informer", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true).AnyTimes() + sii.EXPECT().Run(gomock.Any()).MinTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + // this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually + time.Sleep(5 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + // this sleep is critical to the test. It ensure there has been enough time for expected function like Run to be invoked in their go routines. + time.Sleep(1 * time.Second) + c2, err := f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, c, c2) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning false, and stopCh not closed, should call Run() and return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(false).AnyTimes() + sii.EXPECT().Run(gomock.Any()) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + expectedI := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return expectedI, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + time.Sleep(1 * time.Second) + close(f.stopCh) + }() + var err error + _, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.NotNil(t, err) + time.Sleep(2 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh closed, should not call Run() more than once and not return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true).AnyTimes() + // may or may not call run initially + sii.EXPECT().Run(gomock.Any()).MaxTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + close(f.stopCh) + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned and encryptAll set to true, should return no error and pass shouldEncrypt as true to newInformer func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true) + sii.EXPECT().Run(gomock.Any()).MinTimes(1).AnyTimes() + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, true, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + encryptAll: true, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + time.Sleep(10 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, nil, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, stopCh not closed, and transform func should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + fields := [][]string{{"something"}} + expectedGVK := schema.GroupVersionKind{} + sii := NewMockSharedIndexInformer(gomock.NewController(t)) + sii.EXPECT().HasSynced().Return(true) + sii.EXPECT().Run(gomock.Any()).MinTimes(1) + sii.EXPECT().SetWatchErrorHandler(gomock.Any()) + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + i := &informer.Informer{ + // need to set this so Run function is not nil + SharedIndexInformer: sii, + } + expectedC := Cache{ + ByOptionsLister: i, + } + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + // we can't test func == func, so instead we check if the output was as expected + input := "someinput" + ouput, err := transform(input) + assert.Nil(t, err) + outputStr, ok := ouput.(string) + assert.True(t, ok, "ouput from transform was expected to be a string") + assert.Equal(t, "someoutput", outputStr) + + assert.Equal(t, client, dynamicClient) + assert.Equal(t, fields, fields) + assert.Equal(t, expectedGVK, gvk) + assert.Equal(t, db, dbClient) + assert.Equal(t, false, shouldEncrypt) + return i, nil + } + f := &CacheFactory{ + dbClient: dbClient, + stopCh: make(chan struct{}), + newInformer: testNewInformer, + informers: map[schema.GroupVersionKind]*guardedInformer{}, + } + + go func() { + // this function ensures that stopCh is open for the duration of this test but if part of a longer process it will be closed eventually + time.Sleep(5 * time.Second) + close(f.stopCh) + }() + var c Cache + var err error + c, err = f.CacheFor(fields, transformFunc, dynamicClient, expectedGVK, false, true) + assert.Nil(t, err) + assert.Equal(t, expectedC, c) + time.Sleep(1 * time.Second) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go new file mode 100644 index 00000000..b9c4dc35 --- /dev/null +++ b/pkg/sqlcache/informer/factory/k8s_cache_mocks_test.go @@ -0,0 +1,223 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/tools/cache (interfaces: SharedIndexInformer) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer +// + +// Package factory is a generated GoMock package. +package factory + +import ( + reflect "reflect" + time "time" + + gomock "go.uber.org/mock/gomock" + cache "k8s.io/client-go/tools/cache" +) + +// MockSharedIndexInformer is a mock of SharedIndexInformer interface. +type MockSharedIndexInformer struct { + ctrl *gomock.Controller + recorder *MockSharedIndexInformerMockRecorder +} + +// MockSharedIndexInformerMockRecorder is the mock recorder for MockSharedIndexInformer. +type MockSharedIndexInformerMockRecorder struct { + mock *MockSharedIndexInformer +} + +// NewMockSharedIndexInformer creates a new mock instance. +func NewMockSharedIndexInformer(ctrl *gomock.Controller) *MockSharedIndexInformer { + mock := &MockSharedIndexInformer{ctrl: ctrl} + mock.recorder = &MockSharedIndexInformerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSharedIndexInformer) EXPECT() *MockSharedIndexInformerMockRecorder { + return m.recorder +} + +// AddEventHandler mocks base method. +func (m *MockSharedIndexInformer) AddEventHandler(arg0 cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandler", arg0) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandler indicates an expected call of AddEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandler), arg0) +} + +// AddEventHandlerWithResyncPeriod mocks base method. +func (m *MockSharedIndexInformer) AddEventHandlerWithResyncPeriod(arg0 cache.ResourceEventHandler, arg1 time.Duration) (cache.ResourceEventHandlerRegistration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddEventHandlerWithResyncPeriod", arg0, arg1) + ret0, _ := ret[0].(cache.ResourceEventHandlerRegistration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddEventHandlerWithResyncPeriod indicates an expected call of AddEventHandlerWithResyncPeriod. +func (mr *MockSharedIndexInformerMockRecorder) AddEventHandlerWithResyncPeriod(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddEventHandlerWithResyncPeriod", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddEventHandlerWithResyncPeriod), arg0, arg1) +} + +// AddIndexers mocks base method. +func (m *MockSharedIndexInformer) AddIndexers(arg0 cache.Indexers) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddIndexers", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddIndexers indicates an expected call of AddIndexers. +func (mr *MockSharedIndexInformerMockRecorder) AddIndexers(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddIndexers", reflect.TypeOf((*MockSharedIndexInformer)(nil).AddIndexers), arg0) +} + +// GetController mocks base method. +func (m *MockSharedIndexInformer) GetController() cache.Controller { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetController") + ret0, _ := ret[0].(cache.Controller) + return ret0 +} + +// GetController indicates an expected call of GetController. +func (mr *MockSharedIndexInformerMockRecorder) GetController() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetController", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetController)) +} + +// GetIndexer mocks base method. +func (m *MockSharedIndexInformer) GetIndexer() cache.Indexer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIndexer") + ret0, _ := ret[0].(cache.Indexer) + return ret0 +} + +// GetIndexer indicates an expected call of GetIndexer. +func (mr *MockSharedIndexInformerMockRecorder) GetIndexer() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndexer", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetIndexer)) +} + +// GetStore mocks base method. +func (m *MockSharedIndexInformer) GetStore() cache.Store { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStore") + ret0, _ := ret[0].(cache.Store) + return ret0 +} + +// GetStore indicates an expected call of GetStore. +func (mr *MockSharedIndexInformerMockRecorder) GetStore() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStore", reflect.TypeOf((*MockSharedIndexInformer)(nil).GetStore)) +} + +// HasSynced mocks base method. +func (m *MockSharedIndexInformer) HasSynced() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasSynced") + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasSynced indicates an expected call of HasSynced. +func (mr *MockSharedIndexInformerMockRecorder) HasSynced() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasSynced", reflect.TypeOf((*MockSharedIndexInformer)(nil).HasSynced)) +} + +// IsStopped mocks base method. +func (m *MockSharedIndexInformer) IsStopped() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsStopped") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsStopped indicates an expected call of IsStopped. +func (mr *MockSharedIndexInformerMockRecorder) IsStopped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStopped", reflect.TypeOf((*MockSharedIndexInformer)(nil).IsStopped)) +} + +// LastSyncResourceVersion mocks base method. +func (m *MockSharedIndexInformer) LastSyncResourceVersion() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastSyncResourceVersion") + ret0, _ := ret[0].(string) + return ret0 +} + +// LastSyncResourceVersion indicates an expected call of LastSyncResourceVersion. +func (mr *MockSharedIndexInformerMockRecorder) LastSyncResourceVersion() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastSyncResourceVersion", reflect.TypeOf((*MockSharedIndexInformer)(nil).LastSyncResourceVersion)) +} + +// RemoveEventHandler mocks base method. +func (m *MockSharedIndexInformer) RemoveEventHandler(arg0 cache.ResourceEventHandlerRegistration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveEventHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveEventHandler indicates an expected call of RemoveEventHandler. +func (mr *MockSharedIndexInformerMockRecorder) RemoveEventHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveEventHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).RemoveEventHandler), arg0) +} + +// Run mocks base method. +func (m *MockSharedIndexInformer) Run(arg0 <-chan struct{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Run", arg0) +} + +// Run indicates an expected call of Run. +func (mr *MockSharedIndexInformerMockRecorder) Run(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSharedIndexInformer)(nil).Run), arg0) +} + +// SetTransform mocks base method. +func (m *MockSharedIndexInformer) SetTransform(arg0 cache.TransformFunc) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetTransform", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetTransform indicates an expected call of SetTransform. +func (mr *MockSharedIndexInformerMockRecorder) SetTransform(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTransform", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetTransform), arg0) +} + +// SetWatchErrorHandler mocks base method. +func (m *MockSharedIndexInformer) SetWatchErrorHandler(arg0 cache.WatchErrorHandler) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWatchErrorHandler", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWatchErrorHandler indicates an expected call of SetWatchErrorHandler. +func (mr *MockSharedIndexInformerMockRecorder) SetWatchErrorHandler(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWatchErrorHandler", reflect.TypeOf((*MockSharedIndexInformer)(nil).SetWatchErrorHandler), arg0) +} diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go new file mode 100644 index 00000000..f4a99b26 --- /dev/null +++ b/pkg/sqlcache/informer/indexer.go @@ -0,0 +1,263 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "sync" + + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + "k8s.io/client-go/tools/cache" +) + +const ( + selectQueryFmt = ` + SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value IN (?%s) + ) + ` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%[1]s_indices" ( + name TEXT NOT NULL, + value TEXT NOT NULL, + key TEXT NOT NULL REFERENCES "%[1]s"(key) ON DELETE CASCADE, + PRIMARY KEY (name, value, key) + )` + createIndexFmt = `CREATE INDEX IF NOT EXISTS "%[1]s_indices_index" ON "%[1]s_indices"(name, value)` + + deleteIndicesFmt = `DELETE FROM "%s_indices" WHERE key = ?` + addIndexFmt = `INSERT INTO "%s_indices" (name, value, key) VALUES (?, ?, ?) ON CONFLICT DO NOTHING` + listByIndexFmt = `SELECT object, objectnonce, dekid FROM "%[1]s" + WHERE key IN ( + SELECT key FROM "%[1]s_indices" + WHERE name = ? AND value = ? + )` + listKeyByIndexFmt = `SELECT DISTINCT key FROM "%s_indices" WHERE name = ? AND value = ?` + listIndexValuesFmt = `SELECT DISTINCT value FROM "%s_indices" WHERE name = ?` +) + +// Indexer is a SQLite-backed cache.Indexer which builds upon Store adding an index table +type Indexer struct { + Store + indexers cache.Indexers + indexersLock sync.RWMutex + + deleteIndicesQuery string + addIndexQuery string + listByIndexQuery string + listKeysByIndexQuery string + listIndexValuesQuery string + + deleteIndicesStmt *sql.Stmt + addIndexStmt *sql.Stmt + listByIndexStmt *sql.Stmt + listKeysByIndexStmt *sql.Stmt + listIndexValuesStmt *sql.Stmt +} + +var _ cache.Indexer = (*Indexer)(nil) + +type Store interface { + DBClient + cache.Store + + GetByKey(key string) (item any, exists bool, err error) + GetName() string + RegisterAfterUpsert(f func(key string, obj any, tx db.TXClient) error) + RegisterAfterDelete(f func(key string, tx db.TXClient) error) + GetShouldEncrypt() bool + GetType() reflect.Type +} + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Prepare(stmt string) *sql.Stmt + CloseStmt(stmt db.Closable) error +} + +// NewIndexer returns a cache.Indexer backed by SQLite for objects of the given example type +func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + createTableQuery := fmt.Sprintf(createTableFmt, db.Sanitize(s.GetName())) + err = tx.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + createIndexQuery := fmt.Sprintf(createIndexFmt, db.Sanitize(s.GetName())) + err = tx.Exec(createIndexQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createIndexQuery, Err: err} + } + err = tx.Commit() + if err != nil { + return nil, err + } + + i := &Indexer{ + Store: s, + indexers: indexers, + } + i.RegisterAfterUpsert(i.AfterUpsert) + + i.deleteIndicesQuery = fmt.Sprintf(deleteIndicesFmt, db.Sanitize(s.GetName())) + i.addIndexQuery = fmt.Sprintf(addIndexFmt, db.Sanitize(s.GetName())) + i.listByIndexQuery = fmt.Sprintf(listByIndexFmt, db.Sanitize(s.GetName())) + i.listKeysByIndexQuery = fmt.Sprintf(listKeyByIndexFmt, db.Sanitize(s.GetName())) + i.listIndexValuesQuery = fmt.Sprintf(listIndexValuesFmt, db.Sanitize(s.GetName())) + + i.deleteIndicesStmt = s.Prepare(i.deleteIndicesQuery) + i.addIndexStmt = s.Prepare(i.addIndexQuery) + i.listByIndexStmt = s.Prepare(i.listByIndexQuery) + i.listKeysByIndexStmt = s.Prepare(i.listKeysByIndexQuery) + i.listIndexValuesStmt = s.Prepare(i.listIndexValuesQuery) + + return i, nil +} + +/* Core methods */ + +// AfterUpsert updates indices of an object +func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error { + // delete all + err := tx.StmtExec(tx.Stmt(i.deleteIndicesStmt), key) + if err != nil { + return &db.QueryError{QueryString: i.deleteIndicesQuery, Err: err} + } + + // re-insert all + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + for indexName, indexFunc := range i.indexers { + values, err := indexFunc(obj) + if err != nil { + return err + } + + for _, value := range values { + err = tx.StmtExec(tx.Stmt(i.addIndexStmt), indexName, value, key) + if err != nil { + return &db.QueryError{QueryString: i.addIndexQuery, Err: err} + } + } + } + return nil +} + +/* Satisfy cache.Indexer */ + +// Index returns a list of items that match the given object on the index function +func (i *Indexer) Index(indexName string, obj any) ([]any, error) { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + indexFunc := i.indexers[indexName] + if indexFunc == nil { + return nil, fmt.Errorf("index with name %s does not exist", indexName) + } + + values, err := indexFunc(obj) + if err != nil { + return nil, err + } + + if len(values) == 0 { + return nil, nil + } + + // typical case + if len(values) == 1 { + return i.ByIndex(indexName, values[0]) + } + + // atypical case - more than one value to lookup + // HACK: sql.Statement.Query does not allow to pass slices in as of go 1.19 - create an ad-hoc statement + query := fmt.Sprintf(selectQueryFmt, db.Sanitize(i.GetName()), strings.Repeat(", ?", len(values)-1)) + stmt := i.Prepare(query) + defer i.CloseStmt(stmt) + // HACK: Query will accept []any but not []string + params := []any{indexName} + for _, value := range values { + params = append(params, value) + } + + rows, err := i.QueryForRows(context.TODO(), stmt, params...) + if err != nil { + return nil, &db.QueryError{QueryString: query, Err: err} + } + return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt()) +} + +// ByIndex returns the stored objects whose set of indexed values +// for the named index includes the given indexed value +func (i *Indexer) ByIndex(indexName, indexedValue string) ([]any, error) { + rows, err := i.QueryForRows(context.TODO(), i.listByIndexStmt, indexName, indexedValue) + if err != nil { + return nil, &db.QueryError{QueryString: i.listByIndexQuery, Err: err} + } + return i.ReadObjects(rows, i.GetType(), i.GetShouldEncrypt()) +} + +// IndexKeys returns a list of the Store keys of the objects whose indexed values in the given index include the given indexed value +func (i *Indexer) IndexKeys(indexName, indexedValue string) ([]string, error) { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + indexFunc := i.indexers[indexName] + if indexFunc == nil { + return nil, fmt.Errorf("Index with name %s does not exist", indexName) + } + + rows, err := i.QueryForRows(context.TODO(), i.listKeysByIndexStmt, indexName, indexedValue) + if err != nil { + return nil, &db.QueryError{QueryString: i.listKeysByIndexQuery, Err: err} + } + return i.ReadStrings(rows) +} + +// ListIndexFuncValues wraps safeListIndexFuncValues and panics in case of I/O errors +func (i *Indexer) ListIndexFuncValues(name string) []string { + result, err := i.safeListIndexFuncValues(name) + if err != nil { + panic(fmt.Errorf("unexpected error in safeListIndexFuncValues: %w", err)) + } + return result +} + +// safeListIndexFuncValues returns all the indexed values of the given index +func (i *Indexer) safeListIndexFuncValues(indexName string) ([]string, error) { + rows, err := i.QueryForRows(context.TODO(), i.listIndexValuesStmt, indexName) + if err != nil { + return nil, &db.QueryError{QueryString: i.listIndexValuesQuery, Err: err} + } + return i.ReadStrings(rows) +} + +// GetIndexers returns the indexers +func (i *Indexer) GetIndexers() cache.Indexers { + i.indexersLock.RLock() + defer i.indexersLock.RUnlock() + return i.indexers +} + +// AddIndexers adds more indexers to this Store. If you call this after you already have data +// in the Store, the results are undefined. +func (i *Indexer) AddIndexers(newIndexers cache.Indexers) error { + i.indexersLock.Lock() + defer i.indexersLock.Unlock() + if i.indexers == nil { + i.indexers = make(map[string]cache.IndexFunc) + } + for k, v := range newIndexers { + i.indexers[k] = v + } + return nil +} diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go new file mode 100644 index 00000000..4118118c --- /dev/null +++ b/pkg/sqlcache/informer/indexer_test.go @@ -0,0 +1,614 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows + +type testStoreObject struct { + Id string + Val string +} + +func TestNewIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or TXClient, should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listKeyByIndexFmt, storeName)) + store.EXPECT().Prepare(fmt.Sprintf(listIndexValuesFmt, storeName)) + indexer, err := NewIndexer(indexers, store) + assert.Nil(t, err) + assert.Equal(t, cache.Indexers(indexers), indexer.indexers) + }}) + tests = append(tests, testCase{description: "NewIndexer() with Store Begin() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + store.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on first call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on second call to Exec(), should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewIndexer() with TXClient Commit() error, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + + objKey := "objKey" + indexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + storeName := "someStoreName" + store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) + client.EXPECT().Commit().Return(fmt.Errorf("error")) + _, err := NewIndexer(indexers, store) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAfterUpsert(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from TXClient should return no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(nil) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.Nil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient StmtExec() should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient second StmtExec() call should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + client := NewMockTXClient(gomock.NewController(t)) + deleteStmt := &sql.Stmt{} + addStmt := &sql.Stmt{} + objKey := "key" + indexer := &Indexer{ + Store: store, + deleteIndicesStmt: deleteStmt, + addIndexStmt: addStmt, + indexers: map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + key := "somekey" + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) + client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) + client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(fmt.Errorf("error")) + testObject := testStoreObject{Id: "something", Val: "a"} + err := indexer.AfterUpsert(key, testObject, client) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "Index() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "Index() where index name is not in indexers, should return error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + _, err := indexer.Index("someotherindexname", testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.Index(indexName, testObject) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "Index() with no errors returned from store and multiple keys returned from index func, should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + indexers: map[string]cache.IndexFunc{ + indexName: func(obj interface{}) ([]string, error) { + return []string{objKey, objKey + "2"}, nil + }, + }, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().GetName().Return("name") + stmt := &sql.Stmt{} + store.EXPECT().Prepare(fmt.Sprintf(selectQueryFmt, "name", ", ?")).Return(stmt) + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + objs, err := indexer.Index(indexName, testObject) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestByIndex(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and multiple objects returned by ReadObjects(), should return multiple objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{testObject, testObject}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with no errors returned from store and no objects returned by ReadObjects(), should return no objects and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil) + objs, err := indexer.ByIndex(indexName, objKey) + assert.Nil(t, err) + assert.Equal(t, []any{}, objs) + }}) + tests = append(tests, testCase{description: "IndexBy() with an error returned from store QueryForRows, should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "IndexBy() with an errors returned from store ReadObjects(), should return an error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + objKey := "key" + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + testObject := testStoreObject{Id: "something", Val: "a"} + + store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil) + store.EXPECT().GetType().Return(reflect.TypeOf(testObject)) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error")) + _, err := indexer.ByIndex(indexName, objKey) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListIndexFuncValues(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListIndexFuncvalues() with no errors returned from store and 1 object returned by ReadObjects(), should return one obj and no error", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, nil) + vals := indexer.ListIndexFuncValues(indexName) + assert.Equal(t, []string{"somestrings"}, vals) + }}) + tests = append(tests, testCase{description: "ListIndexFuncvalues() with QueryForRows() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + tests = append(tests, testCase{description: "ListIndexFuncvalues() with ReadStrings() error returned from store, should panic", test: func(t *testing.T) { + store := NewMockStore(gomock.NewController(t)) + rows := &sql.Rows{} + listStmt := &sql.Stmt{} + indexName := "someindexname" + indexer := &Indexer{ + Store: store, + listByIndexStmt: listStmt, + } + store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil) + store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, fmt.Errorf("error")) + assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) }) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestGetIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{ + indexers: expectedIndexers, + } + indexers := indexer.GetIndexers() + assert.Equal(t, cache.Indexers(expectedIndexers), indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestAddIndexers(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "GetIndexers() should return indexers fron indexers field", test: func(t *testing.T) { + objKey := "key" + expectedIndexers := map[string]cache.IndexFunc{ + "a": func(obj interface{}) ([]string, error) { + return []string{objKey}, nil + }, + } + indexer := &Indexer{} + err := indexer.AddIndexers(expectedIndexers) + assert.Nil(t, err) + assert.ObjectsAreEqual(cache.Indexers(expectedIndexers), indexer.indexers) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go new file mode 100644 index 00000000..a74c7029 --- /dev/null +++ b/pkg/sqlcache/informer/informer.go @@ -0,0 +1,94 @@ +/* +package sql provides an Informer and Indexer that uses SQLite as a store, instead of an in-memory store like a map. +*/ + +package informer + +import ( + "context" + "time" + + "github.com/rancher/steve/pkg/sqlcache/partition" + sqlStore "github.com/rancher/steve/pkg/sqlcache/store" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/tools/cache" +) + +// Informer is a SQLite-backed cache.SharedIndexInformer that can execute queries on listprocessor structs +type Informer struct { + cache.SharedIndexInformer + ByOptionsLister +} + +type ByOptionsLister interface { + ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) +} + +// this is set to a var so that it can be overridden by test code for mocking purposes +var newInformer = cache.NewSharedIndexInformer + +// NewInformer returns a new SQLite-backed Informer for the type specified by schema in unstructured.Unstructured form +// using the specified client +func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*Informer, error) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + a, err := client.List(context.Background(), options) + return a, err + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return client.Watch(context.Background(), options) + }, + } + + example := &unstructured.Unstructured{} + example.SetGroupVersionKind(gvk) + + // avoids the informer to periodically resync (re-list) its resources + // currently it is a work hypothesis that, when interacting with the UI, this should not be needed + resyncPeriod := time.Duration(0) + + sii := newInformer(listWatcher, example, resyncPeriod, cache.Indexers{}) + if transform != nil { + if err := sii.SetTransform(transform); err != nil { + return nil, err + } + } + + name := informerNameFromGVK(gvk) + + s, err := sqlStore.NewStore(example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, shouldEncrypt, name) + if err != nil { + return nil, err + } + loi, err := NewListOptionIndexer(fields, s, namespaced) + if err != nil { + return nil, err + } + + // HACK: replace the default informer's indexer with the SQL based one + UnsafeSet(sii, "indexer", loi) + + return &Informer{ + SharedIndexInformer: sii, + ByOptionsLister: loi, + }, nil +} + +// ListByOptions returns objects according to the specified list options and partitions. +// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (i *Informer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + return i.ByOptionsLister.ListByOptions(ctx, lo, partitions, namespace) +} + +func informerNameFromGVK(gvk schema.GroupVersionKind) string { + return gvk.Group + "_" + gvk.Version + "_" + gvk.Kind +} diff --git a/pkg/sqlcache/informer/informer_mocks_test.go b/pkg/sqlcache/informer/informer_mocks_test.go new file mode 100644 index 00000000..9eff0612 --- /dev/null +++ b/pkg/sqlcache/informer/informer_mocks_test.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + reflect "reflect" + + partition "github.com/rancher/steve/pkg/sqlcache/partition" + gomock "go.uber.org/mock/gomock" + unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" +) + +// MockByOptionsLister is a mock of ByOptionsLister interface. +type MockByOptionsLister struct { + ctrl *gomock.Controller + recorder *MockByOptionsListerMockRecorder +} + +// MockByOptionsListerMockRecorder is the mock recorder for MockByOptionsLister. +type MockByOptionsListerMockRecorder struct { + mock *MockByOptionsLister +} + +// NewMockByOptionsLister creates a new mock instance. +func NewMockByOptionsLister(ctrl *gomock.Controller) *MockByOptionsLister { + mock := &MockByOptionsLister{ctrl: ctrl} + mock.recorder = &MockByOptionsListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockByOptionsLister) EXPECT() *MockByOptionsListerMockRecorder { + return m.recorder +} + +// ListByOptions mocks base method. +func (m *MockByOptionsLister) ListByOptions(arg0 context.Context, arg1 ListOptions, arg2 []partition.Partition, arg3 string) (*unstructured.UnstructuredList, int, string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByOptions", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*unstructured.UnstructuredList) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(string) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// ListByOptions indicates an expected call of ListByOptions. +func (mr *MockByOptionsListerMockRecorder) ListByOptions(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByOptions", reflect.TypeOf((*MockByOptionsLister)(nil).ListByOptions), arg0, arg1, arg2, arg3) +} diff --git a/pkg/sqlcache/informer/informer_test.go b/pkg/sqlcache/informer/informer_test.go new file mode 100644 index 00000000..7bb6afd6 --- /dev/null +++ b/pkg/sqlcache/informer/informer_test.go @@ -0,0 +1,345 @@ +package informer + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/rancher/steve/pkg/sqlcache/partition" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/tools/cache" +) + +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient + +func TestNewInformer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "NewInformer() with no errors returned, should return no error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(context.Background(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + informer, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewStore(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with errors returned from NewListOptionIndexer(), should return an error", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewInformer() with transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + txClient := NewMockTXClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{} + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore + // is tested in depth in its own package. + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() + + // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer + // is tested in depth in its own indexer_test.go + dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + informer, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Nil(t, err) + assert.NotNil(t, informer.ByOptionsLister) + assert.NotNil(t, informer.SharedIndexInformer) + assert.NotNil(t, mockInformer.transformFunc) + + // we can't test func == func, so instead we check if the output was as expected + input := "someinput" + ouput, err := mockInformer.transformFunc(input) + assert.Nil(t, err) + outputStr, ok := ouput.(string) + assert.True(t, ok, "ouput from transform was expected to be a string") + assert.Equal(t, "someoutput", outputStr) + + newInformer = cache.NewSharedIndexInformer + }}) + tests = append(tests, testCase{description: "NewInformer() unable to set transform func", test: func(t *testing.T) { + dbClient := NewMockDBClient(gomock.NewController(t)) + dynamicClient := NewMockResourceInterface(gomock.NewController(t)) + mockInformer := mockInformer{ + setTranformErr: fmt.Errorf("some error"), + } + testNewInformer := func(lw cache.ListerWatcher, + exampleObject runtime.Object, + defaultEventHandlerResyncPeriod time.Duration, + indexers cache.Indexers) cache.SharedIndexInformer { + return &mockInformer + } + newInformer = testNewInformer + + fields := [][]string{{"something"}} + gvk := schema.GroupVersionKind{} + + transformFunc := func(input interface{}) (interface{}, error) { + return "someoutput", nil + } + _, err := NewInformer(dynamicClient, fields, transformFunc, gvk, dbClient, false, true) + assert.Error(t, err) + newInformer = cache.NewSharedIndexInformer + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestInformerListByOptions(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListByOptions() with no errors returned, should return no error and return value from indexer's ListByOptions()", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + expectedList := &unstructured.UnstructuredList{ + Object: map[string]interface{}{"s": 2}, + Items: []unstructured.Unstructured{{ + Object: map[string]interface{}{"s": 2}, + }}, + } + expectedTotal := len(expectedList.Items) + expectedContinueToken := "123" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(expectedList, expectedTotal, expectedContinueToken, nil) + list, total, continueToken, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.Nil(t, err) + assert.Equal(t, expectedList, list) + assert.Equal(t, len(expectedList.Items), total) + assert.Equal(t, expectedContinueToken, continueToken) + }}) + tests = append(tests, testCase{description: "ListByOptions() with indexer ListByOptions error, should return error", test: func(t *testing.T) { + indexer := NewMockByOptionsLister(gomock.NewController(t)) + informer := &Informer{ + ByOptionsLister: indexer, + } + lo := ListOptions{} + var partitions []partition.Partition + ns := "somens" + indexer.EXPECT().ListByOptions(context.TODO(), lo, partitions, ns).Return(nil, 0, "", fmt.Errorf("error")) + _, _, _, err := informer.ListByOptions(context.TODO(), lo, partitions, ns) + assert.NotNil(t, err) + }}) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +// Note: SQLite based caching uses an Informer that unsafely sets the Indexer as the ability to set it is not present +// in client-go at the moment. Long term, we look forward contribute a patch to client-go to make that configurable. +// Until then, we are adding this canary test that will panic in case the indexer cannot be set. +func TestUnsafeSet(t *testing.T) { + listWatcher := &cache.ListWatch{ + ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { + return &unstructured.UnstructuredList{}, nil + }, + WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) { + return dummyWatch{}, nil + }, + } + + sii := cache.NewSharedIndexInformer(listWatcher, &unstructured.Unstructured{}, 0, cache.Indexers{}) + + // will panic if SharedIndexInformer stops having a *Indexer field called "indexer" + UnsafeSet(sii, "indexer", &Indexer{}) +} + +type dummyWatch struct{} + +func (dummyWatch) Stop() { +} + +func (dummyWatch) ResultChan() <-chan watch.Event { + result := make(chan watch.Event) + defer close(result) + return result +} + +// mockInformer is a mock of cache.SharedIndexInformer. Unlike other types, we can't generate this using mockgen because we use a unsafeSet to replace the +// indexer field, which is a struct field. This won't exist on the mock, producing an error. So we need to implement our own mock which actually has this field. +type mockInformer struct { + transformFunc cache.TransformFunc + setTranformErr error + indexer cache.Indexer +} + +func (m *mockInformer) AddEventHandler(handler cache.ResourceEventHandler) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) AddEventHandlerWithResyncPeriod(handler cache.ResourceEventHandler, resyncPeriod time.Duration) (cache.ResourceEventHandlerRegistration, error) { + return nil, nil +} +func (m *mockInformer) RemoveEventHandler(handle cache.ResourceEventHandlerRegistration) error { + return nil +} +func (m *mockInformer) GetStore() cache.Store { return nil } +func (m *mockInformer) GetController() cache.Controller { return nil } +func (m *mockInformer) Run(stopCh <-chan struct{}) {} +func (m *mockInformer) HasSynced() bool { return false } +func (m *mockInformer) LastSyncResourceVersion() string { return "" } +func (m *mockInformer) SetWatchErrorHandler(handler cache.WatchErrorHandler) error { return nil } +func (m *mockInformer) IsStopped() bool { return false } +func (m *mockInformer) AddIndexers(indexers cache.Indexers) error { return nil } +func (m *mockInformer) GetIndexer() cache.Indexer { return nil } +func (m *mockInformer) SetTransform(handler cache.TransformFunc) error { + m.transformFunc = handler + return m.setTranformErr +} diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go new file mode 100644 index 00000000..8fc2fd33 --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -0,0 +1,551 @@ +package informer + +import ( + "context" + "database/sql" + "encoding/gob" + "errors" + "fmt" + "regexp" + "sort" + "strconv" + "strings" + + "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/client-go/tools/cache" + + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/partition" +) + +// ListOptionIndexer extends Indexer by allowing queries based on ListOption +type ListOptionIndexer struct { + *Indexer + + namespaced bool + indexedFields []string + + addFieldQuery string + deleteFieldQuery string + + addFieldStmt *sql.Stmt + deleteFieldStmt *sql.Stmt +} + +var ( + defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"} + defaultIndexNamespaced = "metadata.namespace" + subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[a-zA-Z./]+])|(\[[0-9]+])`) + + ErrInvalidColumn = errors.New("supplied column is invalid") +) + +const ( + matchFmt = `%%%s%%` + strictMatchFmt = `%s` + createFieldsTableFmt = `CREATE TABLE "%s_fields" ( + key TEXT NOT NULL PRIMARY KEY, + %s + )` + createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")` + + failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items: %w" +) + +// NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK +// ListOptionIndexer is also able to satisfy ListOption queries on indexed (sub)fields +// Fields are specified as slices (eg. "metadata.resourceVersion" is ["metadata", "resourceVersion"]) +func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOptionIndexer, error) { + // necessary in order to gob/ungob unstructured.Unstructured objects + gob.Register(map[string]interface{}{}) + gob.Register([]interface{}{}) + + i, err := NewIndexer(cache.Indexers{}, s) + if err != nil { + return nil, err + } + + var indexedFields []string + for _, f := range defaultIndexedFields { + indexedFields = append(indexedFields, f) + } + if namespaced { + indexedFields = append(indexedFields, defaultIndexNamespaced) + } + for _, f := range fields { + indexedFields = append(indexedFields, toColumnName(f)) + } + + l := &ListOptionIndexer{ + Indexer: i, + namespaced: namespaced, + indexedFields: indexedFields, + } + l.RegisterAfterUpsert(l.afterUpsert) + l.RegisterAfterDelete(l.afterDelete) + columnDefs := make([]string, len(indexedFields)) + for index, field := range indexedFields { + column := fmt.Sprintf(`"%s" TEXT`, field) + columnDefs[index] = column + } + + tx, err := l.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, db.Sanitize(i.GetName()), strings.Join(columnDefs, ", "))) + if err != nil { + return nil, err + } + + columns := make([]string, len(indexedFields)) + qmarks := make([]string, len(indexedFields)) + setStatements := make([]string, len(indexedFields)) + + for index, field := range indexedFields { + // create index for field + err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, db.Sanitize(i.GetName()), field, db.Sanitize(i.GetName()), field)) + if err != nil { + return nil, err + } + + // format field into column for prepared statement + column := fmt.Sprintf(`"%s"`, field) + columns[index] = column + + // add placeholder for column's value in prepared statement + qmarks[index] = "?" + + // add formatted set statement for prepared statement + setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field) + setStatements[index] = setStatement + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + l.addFieldQuery = fmt.Sprintf( + `INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`, + db.Sanitize(i.GetName()), + strings.Join(columns, ", "), + strings.Join(qmarks, ", "), + strings.Join(setStatements, ", "), + ) + l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, db.Sanitize(i.GetName())) + + l.addFieldStmt = l.Prepare(l.addFieldQuery) + l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery) + + return l, nil +} + +/* Core methods */ + +// afterUpsert saves sortable/filterable fields into tables +func (l *ListOptionIndexer) afterUpsert(key string, obj any, tx db.TXClient) error { + args := []any{key} + for _, field := range l.indexedFields { + value, err := getField(obj, field) + if err != nil { + logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err) + } + return err + } + switch typedValue := value.(type) { + case nil: + args = append(args, "") + case int, bool, string: + args = append(args, fmt.Sprint(typedValue)) + case []string: + args = append(args, strings.Join(typedValue, "|")) + default: + err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value) + cErr := tx.Cancel() + if cErr != nil { + return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2) + } + return err2 + } + } + + err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...) + if err != nil { + return &db.QueryError{QueryString: l.addFieldQuery, Err: err} + } + return nil +} + +func (l *ListOptionIndexer) afterDelete(key string, tx db.TXClient) error { + args := []any{key} + + err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...) + if err != nil { + return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} + } + return nil +} + +// ListByOptions returns objects according to the specified list options and partitions. +// Specifically: +// - an unstructured list of resources belonging to any of the specified partitions +// - the total number of resources (returned list might be a subset depending on pagination options in lo) +// - a continue token, if there are more pages after the returned one +// - an error instead of all of the above if anything went wrong +func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { + // 1- Intro: SELECT and JOIN clauses + query := fmt.Sprintf(`SELECT o.object, o.objectnonce, o.dekid FROM "%s" o`, db.Sanitize(l.GetName())) + query += "\n " + query += fmt.Sprintf(`JOIN "%s_fields" f ON o.key = f.key`, db.Sanitize(l.GetName())) + params := []any{} + + // 2- Filtering: WHERE clauses (from lo.Filters) + whereClauses := []string{} + for _, orFilters := range lo.Filters { + orClause, orParams, err := l.buildORClauseFromFilters(orFilters) + if err != nil { + return nil, 0, "", err + } + if orClause == "" { + continue + } + whereClauses = append(whereClauses, orClause) + params = append(params, orParams...) + } + + // WHERE clauses (from namespace) + if namespace != "" && namespace != "*" { + whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, namespace) + } + + // WHERE clauses (from partitions and their corresponding parameters) + partitionClauses := []string{} + for _, partition := range partitions { + if partition.Passthrough { + // nothing to do, no extra filtering to apply by definition + } else { + singlePartitionClauses := []string{} + + // filter by namespace + if partition.Namespace != "" && partition.Namespace != "*" { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) + params = append(params, partition.Namespace) + } + + // optionally filter by names + if !partition.All { + names := partition.Names + + if len(names) == 0 { + // degenerate case, there will be no results + singlePartitionClauses = append(singlePartitionClauses, "FALSE") + } else { + singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.name" IN (?%s)`, strings.Repeat(", ?", len(partition.Names)-1))) + // sort for reproducibility + sortedNames := partition.Names.UnsortedList() + sort.Strings(sortedNames) + for _, name := range sortedNames { + params = append(params, name) + } + } + } + + if len(singlePartitionClauses) > 0 { + partitionClauses = append(partitionClauses, strings.Join(singlePartitionClauses, " AND ")) + } + } + } + if len(partitions) == 0 { + // degenerate case, there will be no results + whereClauses = append(whereClauses, "FALSE") + } + if len(partitionClauses) == 1 { + whereClauses = append(whereClauses, partitionClauses[0]) + } + if len(partitionClauses) > 1 { + whereClauses = append(whereClauses, "(\n ("+strings.Join(partitionClauses, ") OR\n (")+")\n)") + } + + if len(whereClauses) > 0 { + query += "\n WHERE\n " + for index, clause := range whereClauses { + query += fmt.Sprintf("(%s)", clause) + if index == len(whereClauses)-1 { + break + } + query += " AND\n " + } + } + + // 2- Sorting: ORDER BY clauses (from lo.Sort) + orderByClauses := []string{} + if len(lo.Sort.PrimaryField) > 0 { + columnName := toColumnName(lo.Sort.PrimaryField) + if err := l.validateColumn(columnName); err != nil { + return nil, 0, "", err + } + + direction := "ASC" + if lo.Sort.PrimaryOrder == DESC { + direction = "DESC" + } + orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction)) + } + if len(lo.Sort.SecondaryField) > 0 { + columnName := toColumnName(lo.Sort.SecondaryField) + if err := l.validateColumn(columnName); err != nil { + return nil, 0, "", err + } + + direction := "ASC" + if lo.Sort.SecondaryOrder == DESC { + direction = "DESC" + } + orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction)) + } + + if len(orderByClauses) > 0 { + query += "\n ORDER BY " + query += strings.Join(orderByClauses, ", ") + } else { + // make sure one default order is always picked + if l.namespaced { + query += "\n ORDER BY f.\"metadata.namespace\" ASC, f.\"metadata.name\" ASC " + } else { + query += "\n ORDER BY f.\"metadata.name\" ASC " + } + } + + // 4- Pagination: LIMIT clause (from lo.Pagination and/or lo.ChunkSize/lo.Resume) + + // before proceeding, save a copy of the query and params without LIMIT/OFFSET + // for COUNTing all results later + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", query) + countParams := params[:] + + limitClause := "" + // take the smallest limit between lo.Pagination and lo.ChunkSize + limit := lo.Pagination.PageSize + if limit == 0 || (lo.ChunkSize > 0 && lo.ChunkSize < limit) { + limit = lo.ChunkSize + } + if limit > 0 { + limitClause = "\n LIMIT ?" + params = append(params, limit) + } + + // OFFSET clause (from lo.Pagination and/or lo.Resume) + offsetClause := "" + offset := 0 + if lo.Resume != "" { + offsetInt, err := strconv.Atoi(lo.Resume) + if err != nil { + return nil, 0, "", err + } + offset = offsetInt + } + if lo.Pagination.Page >= 1 { + offset += lo.Pagination.PageSize * (lo.Pagination.Page - 1) + } + if offset > 0 { + offsetClause = "\n OFFSET ?" + params = append(params, offset) + } + + // assemble and log the final query + query += limitClause + query += offsetClause + logrus.Debugf("ListOptionIndexer prepared statement: %v", query) + logrus.Debugf("Params: %v", params) + + // execute + stmt := l.Prepare(query) + defer l.CloseStmt(stmt) + + tx, err := l.BeginTx(ctx, false) + if err != nil { + return nil, 0, "", err + } + + txStmt := tx.Stmt(stmt) + rows, err := txStmt.QueryContext(ctx, params...) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", &db.QueryError{QueryString: query, Err: err} + } + items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt()) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", err + } + + total := len(items) + // if limit or offset were set, execute counting of all rows + if limit > 0 || offset > 0 { + countStmt := l.Prepare(countQuery) + defer l.CloseStmt(countStmt) + txStmt := tx.Stmt(countStmt) + rows, err := txStmt.QueryContext(ctx, countParams...) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", fmt.Errorf("error executing query: %w", err) + } + total, err = l.ReadInt(rows) + if err != nil { + if cerr := tx.Cancel(); cerr != nil { + return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) + } + return nil, 0, "", fmt.Errorf("error reading query results: %w", err) + } + } + if err := tx.Commit(); err != nil { + return nil, 0, "", err + } + + continueToken := "" + if limit > 0 && offset+len(items) < total { + continueToken = fmt.Sprintf("%d", offset+limit) + } + + return toUnstructuredList(items), total, continueToken, nil +} + +func (l *ListOptionIndexer) validateColumn(column string) error { + for _, v := range l.indexedFields { + if v == column { + return nil + } + } + return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn) +} + +// buildORClause creates an SQLite compatible query that ORs conditions built from passed filters +func (l *ListOptionIndexer) buildORClauseFromFilters(orFilters OrFilter) (string, []any, error) { + var orWhereClause string + var params []any + + for index, filter := range orFilters.Filters { + opString := "LIKE" + if filter.Op == NotEq { + opString = "NOT LIKE" + } + columnName := toColumnName(filter.Field) + if err := l.validateColumn(columnName); err != nil { + return "", nil, err + } + + orWhereClause += fmt.Sprintf(`f."%s" %s ? ESCAPE '\'`, columnName, opString) + format := strictMatchFmt + if filter.Partial { + format = matchFmt + } + match := filter.Match + // To allow matches on the backslash itself, the character needs to be replaced first. + // Otherwise, it will undo the following replacements. + match = strings.ReplaceAll(match, `\`, `\\`) + match = strings.ReplaceAll(match, `_`, `\_`) + match = strings.ReplaceAll(match, `%`, `\%`) + params = append(params, fmt.Sprintf(format, match)) + if index == len(orFilters.Filters)-1 { + continue + } + orWhereClause += " OR " + } + return orWhereClause, params, nil +} + +// toColumnName returns the column name corresponding to a field expressed as string slice +func toColumnName(s []string) string { + return db.Sanitize(strings.Join(s, ".")) +} + +// getField extracts the value of a field expressed as a string path from an unstructured object +func getField(a any, field string) (any, error) { + subFields := extractSubFields(field) + o, ok := a.(*unstructured.Unstructured) + if !ok { + return nil, fmt.Errorf("unexpected object type, expected unstructured.Unstructured: %v", a) + } + + var obj interface{} + var found bool + var err error + obj = o.Object + for i, subField := range subFields { + switch t := obj.(type) { + case map[string]interface{}: + subField = strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]") + obj, found, err = unstructured.NestedFieldNoCopy(t, subField) + if err != nil { + return nil, err + } + if !found { + // particularly with labels/annotation indexes, it is totally possible that some objects won't have these, + // so either we this is not an error state or it could be an error state with a type that callers can check for + return nil, nil + } + case []interface{}: + if strings.HasPrefix(subField, "[") && strings.HasSuffix(subField, "]") { + key, err := strconv.Atoi(strings.TrimSuffix(strings.TrimPrefix(subField, "["), "]")) + if err != nil { + return nil, fmt.Errorf("[listoption indexer] failed to convert subfield [%s] to int in listoption index: %w", subField, err) + } + if key >= len(t) { + return nil, fmt.Errorf("[listoption indexer] given index is too large for slice of len %d", len(t)) + } + obj = fmt.Sprintf("%v", t[key]) + } else if i == len(subFields)-1 { + result := make([]string, len(t)) + for index, v := range t { + itemVal, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + itemStr, ok := itemVal[subField].(string) + if !ok { + return nil, fmt.Errorf(failedToGetFromSliceFmt, subField, err) + } + result[index] = itemStr + } + return result, nil + } + default: + return nil, fmt.Errorf("[listoption indexer] failed to parse subfields: %v", subFields) + } + } + return obj, nil +} + +func extractSubFields(fields string) []string { + subfields := make([]string, 0) + for _, subField := range subfieldRegex.FindAllString(fields, -1) { + subfields = append(subfields, strings.TrimSuffix(subField, ".")) + } + return subfields +} + +// toUnstructuredList turns a slice of unstructured objects into an unstructured.UnstructuredList +func toUnstructuredList(items []any) *unstructured.UnstructuredList { + objectItems := make([]map[string]any, len(items)) + result := &unstructured.UnstructuredList{ + Items: make([]unstructured.Unstructured, len(items)), + Object: map[string]interface{}{"items": objectItems}, + } + for i, item := range items { + result.Items[i] = *item.(*unstructured.Unstructured) + objectItems[i] = item.(*unstructured.Unstructured).Object + } + return result +} diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go new file mode 100644 index 00000000..a84ff727 --- /dev/null +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -0,0 +1,751 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/rancher/steve/pkg/sqlcache/partition" +) + +func TestNewListOptionIndexer(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T) + } + + var tests []testCase + tests = append(tests, testCase{description: "NewListOptionIndexer() with no errors returned, should return no error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterDelete(gomock.Any()) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + // create field table + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + // create field table indexes + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Commit().Return(nil) + + loi, err := NewListOptionIndexer(fields, store, true) + assert.Nil(t, err) + assert.NotNil(t, loi) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from NewIndxer(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error returned from Begin(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterDelete(gomock.Any()) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, false) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Exec() when creating fields table, should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterDelete(gomock.Any()) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + tests = append(tests, testCase{description: "NewListOptionIndexer() with error from Commit(), should return an error", test: func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + fields := [][]string{{"something"}} + id := "somename" + stmt := &sql.Stmt{} + // logic for NewIndexer(), only interested in if this results in error or not + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + store.EXPECT().GetName().Return(id).AnyTimes() + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Exec(gomock.Any(), gomock.Any()).Return(nil) + txClient.EXPECT().Commit().Return(nil) + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() + // end NewIndexer() logic + + store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterDelete(gomock.Any()) + + store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) + txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + + _, err := NewListOptionIndexer(fields, store, true) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t) }) + } +} + +func TestListByOptions(t *testing.T) { + type testCase struct { + description string + listOptions ListOptions + partitions []partition.Partition + ns string + expectedCountStmt string + expectedCountStmtArgs []any + expectedStmt string + expectedStmtArgs []any + expectedList *unstructured.UnstructuredList + returnList []any + expectedContToken string + expectedErr error + } + + testObject := testStoreObject{Id: "something", Val: "a"} + unstrTestObjectMap, err := runtime.DefaultUnstructuredConverter.ToUnstructured(&testObject) + assert.Nil(t, err) + // unstrTestObject + var tests []testCase + tests = append(tests, testCase{ + description: "ListByOptions() with no errors returned, should not return an error", + listOptions: ListOptions{}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions() with an empty filter, should not return an error", + listOptions: ListOptions{ + Filters: []OrFilter{{[]Filter{}}}, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{}}, Items: []unstructured.Unstructured{}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with ChunkSize set should set limit in prepared sql.Stmt", + listOptions: ListOptions{ChunkSize: 2}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []interface{}{2}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC )`, + expectedCountStmtArgs: []interface{}{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Resume set should set offset in prepared sql.Stmt", + listOptions: ListOptions{Resume: "4"}, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + OFFSET ?`, + expectedStmtArgs: []interface{}{4}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC )`, + expectedCountStmtArgs: []interface{}{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter should select where that filter is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somevalue", + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Op set top NotEq should select where that filter is not true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somevalue", + Op: NotEq, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somevalue"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with 1 filter with Partial set to true should select where that partial match on that filter's value is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somevalue", + Partial: true, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with 1 OrFilter set with multiple filters should select where any of those filters are true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somevalue", + Partial: true, + }, + { + Field: []string{"metadata", "somefield"}, + Match: "someothervalue", + }, + { + Field: []string{"metadata", "somefield"}, + Match: "somethirdvalue", + Op: NotEq, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\' OR f."metadata.somefield" LIKE ? ESCAPE '\' OR f."metadata.somefield" NOT LIKE ? ESCAPE '\') AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%", "someothervalue", "somethirdvalue"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with multiple OrFilters set should select where all OrFilters contain one filter that is true in prepared sql.Stmt", + listOptions: ListOptions{Filters: []OrFilter{ + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somevalue", + Partial: true, + }, + { + Field: []string{"status", "someotherfield"}, + Match: "someothervalue", + Op: NotEq, + }, + }, + }, + { + Filters: []Filter{ + { + Field: []string{"metadata", "somefield"}, + Match: "somethirdvalue", + Op: Eq, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "test4", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.somefield" LIKE ? ESCAPE '\' OR f."status.someotherfield" NOT LIKE ? ESCAPE '\') AND + (f."metadata.somefield" LIKE ? ESCAPE '\') AND + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"%somevalue%", "someothervalue", "somethirdvalue", "test4"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.PrimaryField set only should sort on that field only, in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + PrimaryField: []string{"metadata", "somefield"}, + }, + }, + partitions: []partition.Partition{}, + ns: "test5", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ?) AND + (FALSE) + ORDER BY f."metadata.somefield" ASC`, + expectedStmtArgs: []any{"test5"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.SecondaryField set only should sort on that field only, in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + SecondaryField: []string{"metadata", "somefield"}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.PrimaryField and Sort.SecondaryField set should sort on PrimaryField in ascending order first and then sort on SecondaryField in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + PrimaryField: []string{"metadata", "somefield"}, + SecondaryField: []string{"status", "someotherfield"}, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" ASC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.PrimaryField and Sort.SecondaryField set and PrimaryOrder set to DESC should sort on PrimaryField in descending order first and then sort on SecondaryField in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + PrimaryField: []string{"metadata", "somefield"}, + SecondaryField: []string{"status", "someotherfield"}, + PrimaryOrder: DESC, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.somefield" DESC, f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.SecondaryField set and Sort.PrimaryOrder set to descending should sort on that SecondaryField in ascending order only and ignore PrimaryOrder in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + SecondaryField: []string{"status", "someotherfield"}, + PrimaryOrder: DESC, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."status.someotherfield" ASC`, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Sort.PrimaryOrder set only should sort on default primary and secondary fields in ascending order in prepared sql.Stmt", + listOptions: ListOptions{ + Sort: Sort{ + PrimaryOrder: DESC, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.PageSize set should set limit to PageSize in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ?`, + expectedStmtArgs: []any{10}, + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC )`, + expectedCountStmtArgs: []interface{}{}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and no PageSize set should not add anything to prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with Pagination.Page and PageSize set limit to PageSize and offset to PageSize * (Page - 1) in prepared sql.Stmt", + listOptions: ListOptions{ + Pagination: Pagination{ + PageSize: 10, + Page: 2, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC + LIMIT ? + OFFSET ?`, + expectedStmtArgs: []any{10, 10}, + + expectedCountStmt: `SELECT COUNT(*) FROM (SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY f."metadata.name" ASC )`, + expectedCountStmtArgs: []interface{}{}, + + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Namespace Partition should select only items where metadata.namespace is equal to Namespace and all other conditions are met in prepared sql.Stmt", + partitions: []partition.Partition{ + { + Namespace: "somens", + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.namespace" = ? AND FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"somens"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a All Partition should select all items that meet all other conditions in prepared sql.Stmt", + partitions: []partition.Partition{ + { + All: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Passthrough Partition should select all items that meet all other conditions prepared sql.Stmt", + partitions: []partition.Partition{ + { + Passthrough: true, + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + ORDER BY f."metadata.name" ASC `, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "ListByOptions with a Names Partition should select only items where metadata.name equals an items in Names and all other conditions are met in prepared sql.Stmt", + partitions: []partition.Partition{ + { + Names: sets.New[string]("someid", "someotherid"), + }, + }, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (f."metadata.name" IN (?, ?)) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"someid", "someotherid"}, + returnList: []any{&unstructured.Unstructured{Object: unstrTestObjectMap}, &unstructured.Unstructured{Object: unstrTestObjectMap}}, + expectedList: &unstructured.UnstructuredList{Object: map[string]interface{}{"items": []map[string]interface{}{unstrTestObjectMap, unstrTestObjectMap}}, Items: []unstructured.Unstructured{{Object: unstrTestObjectMap}, {Object: unstrTestObjectMap}}}, + expectedContToken: "", + expectedErr: nil, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + txClient := NewMockTXClient(gomock.NewController(t)) + store := NewMockStore(gomock.NewController(t)) + stmts := NewMockStmt(gomock.NewController(t)) + i := &Indexer{ + Store: store, + } + lii := &ListOptionIndexer{ + Indexer: i, + indexedFields: []string{"metadata.somefield", "status.someotherfield"}, + } + stmt := &sql.Stmt{} + rows := &sql.Rows{} + objType := reflect.TypeOf(testObject) + store.EXPECT().BeginTx(gomock.Any(), false).Return(txClient, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts).AnyTimes() + store.EXPECT().GetName().Return("something").AnyTimes() + store.EXPECT().Prepare(test.expectedStmt).Do(func(a ...any) { + fmt.Println(a) + }).Return(stmt) + if args := test.expectedStmtArgs; args != nil { + stmts.EXPECT().QueryContext(gomock.Any(), gomock.Any()).Return(rows, nil).AnyTimes() + } else if strings.Contains(test.expectedStmt, "LIMIT") { + stmts.EXPECT().QueryContext(gomock.Any(), args...).Return(rows, nil) + txClient.EXPECT().Stmt(gomock.Any()).Return(stmts) + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } else { + stmts.EXPECT().QueryContext(gomock.Any()).Return(rows, nil) + } + store.EXPECT().GetType().Return(objType) + store.EXPECT().GetShouldEncrypt().Return(false) + store.EXPECT().ReadObjects(rows, objType, false).Return(test.returnList, nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + + if test.expectedCountStmt != "" { + store.EXPECT().Prepare(test.expectedCountStmt).Return(stmt) + //store.EXPECT().QueryForRows(context.TODO(), stmt, test.expectedCountStmtArgs...).Return(rows, nil) + store.EXPECT().ReadInt(rows).Return(len(test.expectedList.Items), nil) + store.EXPECT().CloseStmt(stmt).Return(nil) + } + txClient.EXPECT().Commit() + list, total, contToken, err := lii.ListByOptions(context.TODO(), test.listOptions, test.partitions, test.ns) + if test.expectedErr == nil { + assert.Nil(t, err) + } else { + assert.Equal(t, test.expectedErr, err) + } + assert.Equal(t, test.expectedList, list) + assert.Equal(t, len(test.expectedList.Items), total) + assert.Equal(t, test.expectedContToken, contToken) + }) + } +} diff --git a/pkg/sqlcache/informer/listoptions.go b/pkg/sqlcache/informer/listoptions.go new file mode 100644 index 00000000..8b894d1a --- /dev/null +++ b/pkg/sqlcache/informer/listoptions.go @@ -0,0 +1,59 @@ +package informer + +type Op string + +const ( + Eq Op = "" + NotEq Op = "!=" +) + +// SortOrder represents whether the list should be ascending or descending. +type SortOrder int + +const ( + // ASC stands for ascending order. + ASC SortOrder = iota + // DESC stands for descending (reverse) order. + DESC +) + +// ListOptions represents the query parameters that may be included in a list request. +type ListOptions struct { + ChunkSize int + Resume string + Filters []OrFilter + Sort Sort + Pagination Pagination +} + +// Filter represents a field to filter by. +// A subfield in an object is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. +type Filter struct { + Field []string + Match string + Op Op + Partial bool +} + +// OrFilter represents a set of possible fields to filter by, where an item may match any filter in the set to be included in the result. +type OrFilter struct { + Filters []Filter +} + +// Sort represents the criteria to sort on. +// The subfield to sort by is represented in a request query using . notation, e.g. 'metadata.name'. +// The subfield is internally represented as a slice, e.g. [metadata, name]. +// The order is represented by prefixing the sort key by '-', e.g. sort=-metadata.name. +type Sort struct { + PrimaryField []string + SecondaryField []string + PrimaryOrder SortOrder + SecondaryOrder SortOrder +} + +// Pagination represents how to return paginated results. +type Pagination struct { + PageSize int + Page int +} diff --git a/pkg/sqlcache/informer/shared_informer_hack.go b/pkg/sqlcache/informer/shared_informer_hack.go new file mode 100644 index 00000000..c11889c9 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_hack.go @@ -0,0 +1,22 @@ +package informer + +import ( + "reflect" + "unsafe" +) + +// UnsafeSet replaces the passed object's field value with the passed value. +func UnsafeSet(object any, field string, value any) { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + wrf.Set(reflect.ValueOf(value)) +} + +// UnsafeGet returns the value of the passed object's for the passed field. +func UnsafeGet(object any, field string) any { + rs := reflect.ValueOf(object).Elem() + rf := rs.FieldByName(field) + wrf := reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() + return wrf.Interface() +} diff --git a/pkg/sqlcache/informer/shared_informer_test.go b/pkg/sqlcache/informer/shared_informer_test.go new file mode 100644 index 00000000..bd143647 --- /dev/null +++ b/pkg/sqlcache/informer/shared_informer_test.go @@ -0,0 +1,325 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package informer + +import ( + "fmt" + "k8s.io/client-go/tools/cache" + "strings" + "sync" + "testing" + "time" + + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + fcache "k8s.io/client-go/tools/cache/testing" + testingclock "k8s.io/utils/clock/testing" +) + +type testListener struct { + lock sync.RWMutex + resyncPeriod time.Duration + expectedItemNames sets.Set[string] + receivedItemNames []string + name string +} + +func newTestListener(name string, resyncPeriod time.Duration, expected ...string) *testListener { + l := &testListener{ + resyncPeriod: resyncPeriod, + expectedItemNames: sets.New[string](expected...), + name: name, + } + return l +} + +func (l *testListener) OnAdd(obj interface{}, isInInitialList bool) { + l.handle(obj) +} + +func (l *testListener) OnUpdate(old, new interface{}) { + l.handle(new) +} + +func (l *testListener) OnDelete(obj interface{}) { +} + +func (l *testListener) handle(obj interface{}) { + key, _ := cache.MetaNamespaceKeyFunc(obj) + fmt.Printf("%s: handle: %v\n", l.name, key) + l.lock.Lock() + defer l.lock.Unlock() + + objectMeta, _ := meta.Accessor(obj) + l.receivedItemNames = append(l.receivedItemNames, objectMeta.GetName()) +} + +func (l *testListener) ok() bool { + fmt.Println("polling") + err := wait.PollImmediate(100*time.Millisecond, 2*time.Second, func() (bool, error) { + if l.satisfiedExpectations() { + return true, nil + } + return false, nil + }) + if err != nil { + return false + } + + // wait just a bit to allow any unexpected stragglers to come in + fmt.Println("sleeping") + time.Sleep(1 * time.Second) + fmt.Println("final check") + return l.satisfiedExpectations() +} + +func (l *testListener) satisfiedExpectations() bool { + l.lock.RLock() + defer l.lock.RUnlock() + + return sets.New[string](l.receivedItemNames...).Equal(l.expectedItemNames) +} + +func TestListenerResyncPeriods(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener 1, never resync + listener1 := newTestListener("listener1", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener1, listener1.resyncPeriod) + + // listener 2, resync every 2s + listener2 := newTestListener("listener2", 2*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener2, listener2.resyncPeriod) + + // listener 3, resync every 3s + listener3 := newTestListener("listener3", 3*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listener3, listener3.resyncPeriod) + listeners := []*testListener{listener1, listener2, listener3} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + // ensure all listeners got the initial List + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener2 gets a resync + clock.Step(2 * time.Second) + + // make sure listener2 got the resync + if !listener2.ok() { + t.Errorf("%s: expected %v, got %v", listener2.name, listener2.expectedItemNames, listener2.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 3 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 3 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener3.receivedItemNames) != 0 { + t.Errorf("listener3: should not have resynced (got %d)", len(listener3.receivedItemNames)) + } + + // reset + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + // advance so listener3 gets a resync + clock.Step(1 * time.Second) + + // make sure listener3 got the resync + if !listener3.ok() { + t.Errorf("%s: expected %v, got %v", listener3.name, listener3.expectedItemNames, listener3.receivedItemNames) + } + + // wait a bit to give errant items a chance to go to 1 and 2 + time.Sleep(1 * time.Second) + + // make sure listeners 1 and 2 got nothing + if len(listener1.receivedItemNames) != 0 { + t.Errorf("listener1: should not have resynced (got %d)", len(listener1.receivedItemNames)) + } + if len(listener2.receivedItemNames) != 0 { + t.Errorf("listener2: should not have resynced (got %d)", len(listener2.receivedItemNames)) + } +} + +// verify that https://github.com/kubernetes/kubernetes/issues/59822 is fixed +func TestSharedInformerInitializationRace(t *testing.T) { + source := fcache.NewFakeControllerSource() + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + listener := newTestListener("raceListener", 0) + + stop := make(chan struct{}) + go informer.AddEventHandlerWithResyncPeriod(listener, listener.resyncPeriod) + go informer.Run(stop) + close(stop) +} + +// TestSharedInformerWatchDisruption simulates a watch that was closed +// with updates to the store during that time. We ensure that handlers with +// resync and no resync see the expected state. +func TestSharedInformerWatchDisruption(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + // create the shared informer and resync every 1s + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + clock := testingclock.NewFakeClock(time.Now()) + UnsafeSet(informer, "clock", clock) + UnsafeSet(UnsafeGet(informer, "processor"), "clock", clock) + + // listener, never resync + listenerNoResync := newTestListener("listenerNoResync", 0, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerNoResync, listenerNoResync.resyncPeriod) + + listenerResync := newTestListener("listenerResync", 1*time.Second, "pod1", "pod2") + informer.AddEventHandlerWithResyncPeriod(listenerResync, listenerResync.resyncPeriod) + listeners := []*testListener{listenerNoResync, listenerResync} + + stop := make(chan struct{}) + defer close(stop) + + go informer.Run(stop) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + // Add pod3, bump pod2 but don't broadcast it, so that the change will be seen only on relist + source.AddDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod3", UID: "pod3", ResourceVersion: "3"}}) + source.ModifyDropWatch(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "4"}}) + + // Ensure that nobody saw any changes + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } + + for _, listener := range listeners { + listener.receivedItemNames = []string{} + } + + listenerNoResync.expectedItemNames = sets.New[string]("pod2", "pod3") + listenerResync.expectedItemNames = sets.New[string]("pod1", "pod2", "pod3") + + // This calls shouldSync, which deletes noResync from the list of syncingListeners + clock.Step(1 * time.Second) + + // Simulate a connection loss (or even just a too-old-watch) + source.ResetWatch() + + // Wait long enough for the reflector to exit and the backoff function to start waiting + // on the fake clock, otherwise advancing the fake clock will have no effect. + // TODO: Make this deterministic by counting the number of waiters on FakeClock + time.Sleep(10 * time.Millisecond) + + // Advance the clock to cause the backoff wait to expire. + clock.Step(1601 * time.Millisecond) + + // Wait long enough for backoff to invoke ListWatch a second time and distribute events + // to listeners. + time.Sleep(10 * time.Millisecond) + + for _, listener := range listeners { + if !listener.ok() { + t.Errorf("%s: expected %v, got %v", listener.name, listener.expectedItemNames, listener.receivedItemNames) + } + } +} + +func TestSharedInformerErrorHandling(t *testing.T) { + source := fcache.NewFakeControllerSource() + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}) + source.ListError = fmt.Errorf("Access Denied") + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + + errCh := make(chan error) + _ = informer.SetWatchErrorHandler(func(_ *cache.Reflector, err error) { + errCh <- err + }) + + stop := make(chan struct{}) + go informer.Run(stop) + + select { + case err := <-errCh: + if !strings.Contains(err.Error(), "Access Denied") { + t.Errorf("Expected 'Access Denied' error. Actual: %v", err) + } + case <-time.After(time.Second): + t.Errorf("Timeout waiting for error handler call") + } + close(stop) +} + +func TestSharedInformerTransformer(t *testing.T) { + // source simulates an apiserver object endpoint. + source := fcache.NewFakeControllerSource() + + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod1", UID: "pod1", ResourceVersion: "1"}}) + source.Add(&v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "pod2", UID: "pod2", ResourceVersion: "2"}}) + + informer := cache.NewSharedInformer(source, &v1.Pod{}, 1*time.Second) + informer.SetTransform(func(obj interface{}) (interface{}, error) { + if pod, ok := obj.(*v1.Pod); ok { + name := pod.GetName() + + if upper := strings.ToUpper(name); upper != name { + copied := pod.DeepCopyObject().(*v1.Pod) + copied.SetName(upper) + return copied, nil + } + } + return obj, nil + }) + + listenerTransformer := newTestListener("listenerTransformer", 0, "POD1", "POD2") + informer.AddEventHandler(listenerTransformer) + + stop := make(chan struct{}) + go informer.Run(stop) + defer close(stop) + + if !listenerTransformer.ok() { + t.Errorf("%s: expected %v, got %v", listenerTransformer.name, listenerTransformer.expectedItemNames, listenerTransformer.receivedItemNames) + } +} diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go new file mode 100644 index 00000000..c269b01b --- /dev/null +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -0,0 +1,347 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: Store) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockStore) Add(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockStoreMockRecorder) Add(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), arg0) +} + +// BeginTx mocks base method. +func (m *MockStore) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockStoreMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockStore)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockStore) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockStoreMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockStore)(nil).CloseStmt), arg0) +} + +// Delete mocks base method. +func (m *MockStore) Delete(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockStoreMockRecorder) Delete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), arg0) +} + +// Get mocks base method. +func (m *MockStore) Get(arg0 any) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Get indicates an expected call of Get. +func (mr *MockStoreMockRecorder) Get(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStore)(nil).Get), arg0) +} + +// GetByKey mocks base method. +func (m *MockStore) GetByKey(arg0 string) (any, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByKey", arg0) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetByKey indicates an expected call of GetByKey. +func (mr *MockStoreMockRecorder) GetByKey(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByKey", reflect.TypeOf((*MockStore)(nil).GetByKey), arg0) +} + +// GetName mocks base method. +func (m *MockStore) GetName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetName indicates an expected call of GetName. +func (mr *MockStoreMockRecorder) GetName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockStore)(nil).GetName)) +} + +// GetShouldEncrypt mocks base method. +func (m *MockStore) GetShouldEncrypt() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetShouldEncrypt") + ret0, _ := ret[0].(bool) + return ret0 +} + +// GetShouldEncrypt indicates an expected call of GetShouldEncrypt. +func (mr *MockStoreMockRecorder) GetShouldEncrypt() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShouldEncrypt", reflect.TypeOf((*MockStore)(nil).GetShouldEncrypt)) +} + +// GetType mocks base method. +func (m *MockStore) GetType() reflect.Type { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetType") + ret0, _ := ret[0].(reflect.Type) + return ret0 +} + +// GetType indicates an expected call of GetType. +func (mr *MockStoreMockRecorder) GetType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetType", reflect.TypeOf((*MockStore)(nil).GetType)) +} + +// List mocks base method. +func (m *MockStore) List() []any { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List") + ret0, _ := ret[0].([]any) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockStoreMockRecorder) List() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStore)(nil).List)) +} + +// ListKeys mocks base method. +func (m *MockStore) ListKeys() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListKeys") + ret0, _ := ret[0].([]string) + return ret0 +} + +// ListKeys indicates an expected call of ListKeys. +func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKeys", reflect.TypeOf((*MockStore)(nil).ListKeys)) +} + +// Prepare mocks base method. +func (m *MockStore) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockStoreMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockStore)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockStore) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockStoreMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockStore)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockStore) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockStoreMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockStore)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockStore) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockStoreMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockStore)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockStore) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockStore)(nil).ReadStrings), arg0) +} + +// RegisterAfterDelete mocks base method. +func (m *MockStore) RegisterAfterDelete(arg0 func(string, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterDelete", arg0) +} + +// RegisterAfterDelete indicates an expected call of RegisterAfterDelete. +func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) +} + +// RegisterAfterUpsert mocks base method. +func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, db.TXClient) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterUpsert", arg0) +} + +// RegisterAfterUpsert indicates an expected call of RegisterAfterUpsert. +func (mr *MockStoreMockRecorder) RegisterAfterUpsert(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpsert", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpsert), arg0) +} + +// Replace mocks base method. +func (m *MockStore) Replace(arg0 []any, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Replace", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Replace indicates an expected call of Replace. +func (mr *MockStoreMockRecorder) Replace(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockStore)(nil).Replace), arg0, arg1) +} + +// Resync mocks base method. +func (m *MockStore) Resync() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Resync") + ret0, _ := ret[0].(error) + return ret0 +} + +// Resync indicates an expected call of Resync. +func (mr *MockStoreMockRecorder) Resync() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Resync", reflect.TypeOf((*MockStore)(nil).Resync)) +} + +// Update mocks base method. +func (m *MockStore) Update(arg0 any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockStoreMockRecorder) Update(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), arg0) +} diff --git a/pkg/sqlcache/informer/store_mocks_test.go b/pkg/sqlcache/informer/store_mocks_test.go new file mode 100644 index 00000000..c1c7d426 --- /dev/null +++ b/pkg/sqlcache/informer/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/informer/tx_mocks_test.go b/pkg/sqlcache/informer/tx_mocks_test.go new file mode 100644 index 00000000..e482df9b --- /dev/null +++ b/pkg/sqlcache/informer/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/lasso/pkg/cache/sql/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/db/transaction Stmt +// + +// Package informer is a generated GoMock package. +package informer + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} diff --git a/pkg/sqlcache/integration_test.go b/pkg/sqlcache/integration_test.go new file mode 100644 index 00000000..3f5f82e0 --- /dev/null +++ b/pkg/sqlcache/integration_test.go @@ -0,0 +1,365 @@ +package sql + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/envtest" + + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" +) + +const testNamespace = "sql-test" + +var defaultPartition = partition.Partition{ + All: true, +} + +type IntegrationSuite struct { + suite.Suite + testEnv envtest.Environment + clientset kubernetes.Clientset + restCfg rest.Config +} + +func (i *IntegrationSuite) SetupSuite() { + i.testEnv = envtest.Environment{} + restCfg, err := i.testEnv.Start() + i.Require().NoError(err, "error when starting env test - this is likely because setup-envtest wasn't done. Check the README for more information") + i.restCfg = *restCfg + clientset, err := kubernetes.NewForConfig(restCfg) + i.Require().NoError(err) + i.clientset = *clientset + testNs := v1.Namespace{ + ObjectMeta: metav1.ObjectMeta{ + Name: testNamespace, + }, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _, err = i.clientset.CoreV1().Namespaces().Create(ctx, &testNs, metav1.CreateOptions{}) + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TearDownSuite() { + err := i.testEnv.Stop() + i.Require().NoError(err) +} + +func (i *IntegrationSuite) TestSQLCacheFilters() { + fields := [][]string{{`metadata`, `annotations[somekey]`}} + require := i.Require() + configMapWithAnnotations := func(name string, annotations map[string]string) v1.ConfigMap { + return v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: testNamespace, + Annotations: annotations, + }, + } + } + createConfigMaps := func(configMaps ...v1.ConfigMap) { + for _, configMap := range configMaps { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + configMapClient := i.clientset.CoreV1().ConfigMaps(testNamespace) + _, err := configMapClient.Create(ctx, &configMap, metav1.CreateOptions{}) + require.NoError(err) + // avoiding defer in a for loop + cancel() + } + } + + // we create some configmaps before the cache starts and some after so that we can test the initial list and + // subsequent watches to make sure both work + // matches the filter for somekey == somevalue + matches := configMapWithAnnotations("matches-filter", map[string]string{"somekey": "somevalue"}) + // partial match for somekey == somevalue (different suffix) + partialMatches := configMapWithAnnotations("partial-matches", map[string]string{"somekey": "somevaluehere"}) + specialCharacterMatch := configMapWithAnnotations("special-character-matches", map[string]string{"somekey": "c%%l_value"}) + backSlashCharacterMatch := configMapWithAnnotations("backslash-character-matches", map[string]string{"somekey": `my\windows\path`}) + createConfigMaps(matches, partialMatches, specialCharacterMatch, backSlashCharacterMatch) + + cache, cacheFactory, err := i.createCacheAndFactory(fields, nil) + require.NoError(err) + defer cacheFactory.Reset() + + // doesn't match the filter for somekey == somevalue + notMatches := configMapWithAnnotations("not-matches-filter", map[string]string{"somekey": "notequal"}) + // has no annotations, shouldn't match any filter + missing := configMapWithAnnotations("missing", nil) + createConfigMaps(notMatches, missing) + + configMapNames := []string{matches.Name, partialMatches.Name, notMatches.Name, missing.Name, specialCharacterMatch.Name, backSlashCharacterMatch.Name} + err = i.waitForCacheReady(configMapNames, testNamespace, cache) + require.NoError(err) + + orFiltersForFilters := func(filters ...informer.Filter) []informer.OrFilter { + return []informer.OrFilter{ + { + Filters: filters, + }, + } + } + tests := []struct { + name string + filters []informer.OrFilter + wantNames []string + }{ + { + name: "matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.Eq, + Partial: false, + }), + wantNames: []string{"matches-filter"}, + }, + { + name: "partial matches filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"matches-filter", "partial-matches"}, + }, + { + name: "no matches for filter with underscore as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalu_", + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "no matches for filter with percent sign as it is interpreted literally", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalu%", + Op: informer.Eq, + Partial: true, + }), + wantNames: nil, + }, + { + name: "match with special characters", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "c%%l_value", + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"special-character-matches"}, + }, + { + name: "match with literal backslash character", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: `my\windows\path`, + Op: informer.Eq, + Partial: true, + }), + wantNames: []string{"backslash-character-matches"}, + }, + { + name: "not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.NotEq, + Partial: false, + }), + wantNames: []string{"partial-matches", "not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "partial not eq filter", + filters: orFiltersForFilters(informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.NotEq, + Partial: true, + }), + wantNames: []string{"not-matches-filter", "missing", "special-character-matches", "backslash-character-matches"}, + }, + { + name: "multiple or filters match", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "notequal", + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "not-matches-filter"}, + }, + { + name: "or filters on different fields", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.Eq, + Partial: true, + }, + informer.Filter{ + Field: []string{`metadata`, `name`}, + Match: "missing", + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{"matches-filter", "partial-matches", "missing"}, + }, + { + name: "and filters, both must match", + filters: []informer.OrFilter{ + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "somevalue", + Op: informer.Eq, + Partial: true, + }, + }, + }, + { + Filters: []informer.Filter{ + { + Field: []string{`metadata`, `name`}, + Match: "matches-filter", + Op: informer.Eq, + Partial: false, + }, + }, + }, + }, + wantNames: []string{"matches-filter"}, + }, + { + name: "no matches", + filters: orFiltersForFilters( + informer.Filter{ + Field: []string{`metadata`, `annotations[somekey]`}, + Match: "valueNotRepresented", + Op: informer.Eq, + Partial: false, + }, + ), + wantNames: []string{}, + }, + } + + for _, test := range tests { + test := test + i.Run(test.name, func() { + options := informer.ListOptions{ + Filters: test.filters, + } + partitions := []partition.Partition{defaultPartition} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + cfgMaps, total, continueToken, err := cache.ListByOptions(ctx, options, partitions, testNamespace) + i.Require().NoError(err) + // since there's no additional pages, the continue token should be empty + i.Require().Equal("", continueToken) + i.Require().NotNil(cfgMaps) + // assert instead of require so that we can see the full evaluation of # of resources returned + i.Assert().Equal(len(test.wantNames), total) + i.Assert().Len(cfgMaps.Items, len(test.wantNames)) + requireNames := sets.Set[string]{} + requireNames.Insert(test.wantNames...) + gotNames := sets.Set[string]{} + for _, configMap := range cfgMaps.Items { + gotNames.Insert(configMap.GetName()) + } + i.Require().True(requireNames.Equal(gotNames), "wanted %v, got %v", requireNames, gotNames) + }) + } +} + +func (i *IntegrationSuite) createCacheAndFactory(fields [][]string, transformFunc cache.TransformFunc) (*factory.Cache, *factory.CacheFactory, error) { + cacheFactory, err := factory.NewCacheFactory() + if err != nil { + return nil, nil, fmt.Errorf("unable to make factory: %w", err) + } + dynamicClient, err := dynamic.NewForConfig(&i.restCfg) + if err != nil { + return nil, nil, fmt.Errorf("unable to make dynamicClient: %w", err) + } + configMapGVK := schema.GroupVersionKind{ + Group: "", + Version: "v1", + Kind: "ConfigMap", + } + configMapGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "configmaps", + } + dynamicResource := dynamicClient.Resource(configMapGVR).Namespace(testNamespace) + cache, err := cacheFactory.CacheFor(fields, transformFunc, dynamicResource, configMapGVK, true, true) + if err != nil { + return nil, nil, fmt.Errorf("unable to make cache: %w", err) + } + return &cache, cacheFactory, nil +} + +func (i *IntegrationSuite) waitForCacheReady(readyResourceNames []string, namespace string, cache *factory.Cache) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + return wait.PollUntilContextCancel(ctx, time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { + var options informer.ListOptions + partitions := []partition.Partition{defaultPartition} + cacheCtx, cacheCancel := context.WithTimeout(ctx, time.Second*5) + defer cacheCancel() + currentResources, total, _, err := cache.ListByOptions(cacheCtx, options, partitions, namespace) + if err != nil { + // note that we don't return the error since that would stop the polling + return false, nil + } + if total != len(readyResourceNames) { + return false, nil + } + wantNames := sets.Set[string]{} + wantNames.Insert(readyResourceNames...) + gotNames := sets.Set[string]{} + for _, current := range currentResources.Items { + name := current.GetName() + if !wantNames.Has(name) { + return true, fmt.Errorf("got resource %s which wasn't expected", name) + } + gotNames.Insert(name) + } + return wantNames.Equal(gotNames), nil + }) +} + +func TestIntegrationSuite(t *testing.T) { + suite.Run(t, new(IntegrationSuite)) +} diff --git a/pkg/sqlcache/partition/partition.go b/pkg/sqlcache/partition/partition.go new file mode 100644 index 00000000..dd21a60e --- /dev/null +++ b/pkg/sqlcache/partition/partition.go @@ -0,0 +1,24 @@ +/* +Package partition represents listing parameters. They can be used to specify which namespaces a caller would like included +in a response, or which specific objects they are looking for. +*/ +package partition + +import ( + "k8s.io/apimachinery/pkg/util/sets" +) + +// Partition represents filtering of a request's results +type Partition struct { + // if true, do not apply any filtering, return all results. Overrides all other fields + Passthrough bool + + // if non-empty, only resources in the specified namespaces will be returned + Namespace string + + // if true, return all results, while still honoring Namespace. Overrides Names + All bool + + // if non-empty, only resources with matching names will be returned + Names sets.Set[string] +} diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go new file mode 100644 index 00000000..75f70b6e --- /dev/null +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +// + +// Package store is a generated GoMock package. +package store + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of TXClient interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Cancel mocks base method. +func (m *MockTXClient) Cancel() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cancel") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cancel indicates an expected call of Cancel. +func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) +} + +// Commit mocks base method. +func (m *MockTXClient) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} + +// StmtExec mocks base method. +func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StmtExec indicates an expected call of StmtExec. +func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) +} diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go new file mode 100644 index 00000000..2c63b2de --- /dev/null +++ b/pkg/sqlcache/store/store.go @@ -0,0 +1,351 @@ +/* +Package store contains the sql backed store. It persists objects to a sqlite database. +*/ +package store + +import ( + "context" + "database/sql" + "fmt" + "reflect" + + "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" + "k8s.io/client-go/tools/cache" + + // needed for drivers + _ "modernc.org/sqlite" +) + +const ( + upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` + deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` + getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` + listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` + listKeysStmtFmt = `SELECT key FROM "%s"` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( + key TEXT UNIQUE NOT NULL PRIMARY KEY, + object BLOB, + objectnonce BLOB, + dekid INTEGER + )` +) + +// Store is a SQLite-backed cache.Store +type Store struct { + DBClient + + name string + typ reflect.Type + keyFunc cache.KeyFunc + shouldEncrypt bool + + upsertQuery string + deleteQuery string + getQuery string + listQuery string + listKeysQuery string + + upsertStmt *sql.Stmt + deleteStmt *sql.Stmt + getStmt *sql.Stmt + listStmt *sql.Stmt + listKeysStmt *sql.Stmt + + afterUpsert []func(key string, obj any, tx db.TXClient) error + afterDelete []func(key string, tx db.TXClient) error +} + +// Test that Store implements cache.Indexer +var _ cache.Store = (*Store)(nil) + +type DBClient interface { + BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) + Prepare(stmt string) *sql.Stmt + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows db.Rows) ([]string, error) + ReadInt(rows db.Rows) (int, error) + Upsert(tx db.TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error + CloseStmt(closable db.Closable) error +} + +// NewStore creates a SQLite-backed cache.Store for objects of the given example type +func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool, name string) (*Store, error) { + s := &Store{ + name: name, + typ: reflect.TypeOf(example), + DBClient: c, + keyFunc: keyFunc, + shouldEncrypt: shouldEncrypt, + afterUpsert: []func(key string, obj any, tx db.TXClient) error{}, + afterDelete: []func(key string, tx db.TXClient) error{}, + } + + // once multiple informerfactories are needed, this can accept the case where table already exists error is received + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return nil, err + } + createTableQuery := fmt.Sprintf(createTableFmt, db.Sanitize(s.name)) + err = txC.Exec(createTableQuery) + if err != nil { + return nil, &db.QueryError{QueryString: createTableQuery, Err: err} + } + + err = txC.Commit() + if err != nil { + return nil, err + } + + s.upsertQuery = fmt.Sprintf(upsertStmtFmt, db.Sanitize(s.name)) + s.deleteQuery = fmt.Sprintf(deleteStmtFmt, db.Sanitize(s.name)) + s.getQuery = fmt.Sprintf(getStmtFmt, db.Sanitize(s.name)) + s.listQuery = fmt.Sprintf(listStmtFmt, db.Sanitize(s.name)) + s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, db.Sanitize(s.name)) + + s.upsertStmt = s.Prepare(s.upsertQuery) + s.deleteStmt = s.Prepare(s.deleteQuery) + s.getStmt = s.Prepare(s.getQuery) + s.listStmt = s.Prepare(s.listQuery) + s.listKeysStmt = s.Prepare(s.listKeysQuery) + + return s, nil +} + +/* Core methods */ +// upsert saves an obj with its key, or updates key with obj if it exists in this Store +func (s *Store) upsert(key string, obj any) error { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + err = s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } + + err = s.runAfterUpsert(key, obj, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// deleteByKey deletes the object associated with key, if it exists in this Store +func (s *Store) deleteByKey(key string) error { + tx, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + err = tx.StmtExec(tx.Stmt(s.deleteStmt), key) + if err != nil { + return &db.QueryError{QueryString: s.deleteQuery, Err: err} + } + + err = s.runAfterDelete(key, tx) + if err != nil { + return err + } + + return tx.Commit() +} + +// GetByKey returns the object associated with the given object's key +func (s *Store) GetByKey(key string) (item any, exists bool, err error) { + rows, err := s.QueryForRows(context.TODO(), s.getStmt, key) + if err != nil { + return nil, false, &db.QueryError{QueryString: s.getQuery, Err: err} + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + return nil, false, err + } + + if len(result) == 0 { + return nil, false, nil + } + + return result[0], true, nil +} + +/* Satisfy cache.Store */ + +// Add saves an obj, or updates it if it exists in this Store +func (s *Store) Add(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + + err = s.upsert(key, obj) + return err +} + +// Update saves an obj, or updates it if it exists in this Store +func (s *Store) Update(obj any) error { + return s.Add(obj) +} + +// Delete deletes the given object, if it exists in this Store +func (s *Store) Delete(obj any) error { + key, err := s.keyFunc(obj) + if err != nil { + return err + } + return s.deleteByKey(key) +} + +// List returns a list of all the currently known objects +// Note: I/O errors will panic this function, as the interface signature does not allow returning errors +func (s *Store) List() []any { + rows, err := s.QueryForRows(context.TODO(), s.listStmt) + if err != nil { + panic(&db.QueryError{QueryString: s.listQuery, Err: err}) + } + result, err := s.ReadObjects(rows, s.typ, s.shouldEncrypt) + if err != nil { + panic(fmt.Errorf("error in Store.List: %w", err)) + } + return result +} + +// ListKeys returns a list of all the keys currently in this Store +// Note: Atm it doesn't appear returning nil in the case of an error has any detrimental effects. An error is not +// uncommon enough nor does it appear to necessitate a panic. +func (s *Store) ListKeys() []string { + rows, err := s.QueryForRows(context.TODO(), s.listKeysStmt) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys: while executing query: %s got error: %v", s.listKeysQuery, err) + return []string{} + } + result, err := s.ReadStrings(rows) + if err != nil { + fmt.Printf("Unexpected error in store.ListKeys: %v\n", err) + return []string{} + } + return result +} + +// Get returns the object with the same key as obj +func (s *Store) Get(obj any) (item any, exists bool, err error) { + key, err := s.keyFunc(obj) + if err != nil { + return nil, false, err + } + + return s.GetByKey(key) +} + +// Replace will delete the contents of the Store, using instead the given list +func (s *Store) Replace(objects []any, _ string) error { + objectMap := map[string]any{} + + for _, object := range objects { + key, err := s.keyFunc(object) + if err != nil { + return err + } + objectMap[key] = object + } + return s.replaceByKey(objectMap) +} + +// replaceByKey will delete the contents of the Store, using instead the given key to obj map +func (s *Store) replaceByKey(objects map[string]any) error { + txC, err := s.BeginTx(context.Background(), true) + if err != nil { + return err + } + + txCListKeys := txC.Stmt(s.listKeysStmt) + + rows, err := s.QueryForRows(context.TODO(), txCListKeys) + if err != nil { + return err + } + keys, err := s.ReadStrings(rows) + if err != nil { + return err + } + + for _, key := range keys { + err = txC.StmtExec(txC.Stmt(s.deleteStmt), key) + if err != nil { + return err + } + err = s.runAfterDelete(key, txC) + if err != nil { + return err + } + } + + for key, obj := range objects { + err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return err + } + err = s.runAfterUpsert(key, obj, txC) + if err != nil { + return err + } + } + + return txC.Commit() +} + +// Resync is a no-op and is deprecated +func (s *Store) Resync() error { + return nil +} + +/* Utilities */ + +// RegisterAfterUpsert registers a func to be called after each upsert +func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC db.TXClient) error) { + s.afterUpsert = append(s.afterUpsert, f) +} + +func (s *Store) GetName() string { + return s.name +} + +func (s *Store) GetShouldEncrypt() bool { + return s.shouldEncrypt +} + +func (s *Store) GetType() reflect.Type { + return s.typ +} + +// keep +// runAfterUpsert executes functions registered to run after upsert +func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error { + for _, f := range s.afterUpsert { + err := f(key, obj, txC) + if err != nil { + return err + } + } + return nil +} + +// RegisterAfterDelete registers a func to be called after each deletion +func (s *Store) RegisterAfterDelete(f func(key string, txC db.TXClient) error) { + s.afterDelete = append(s.afterDelete, f) +} + +// keep +// runAfterDelete executes functions registered to run after upsert +func (s *Store) runAfterDelete(key string, txC db.TXClient) error { + for _, f := range s.afterDelete { + err := f(key, txC) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/sqlcache/store/store_mocks_test.go b/pkg/sqlcache/store/store_mocks_test.go new file mode 100644 index 00000000..d30df82b --- /dev/null +++ b/pkg/sqlcache/store/store_mocks_test.go @@ -0,0 +1,165 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + db "github.com/rancher/steve/pkg/sqlcache/db" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockDBClient is a mock of DBClient interface. +type MockDBClient struct { + ctrl *gomock.Controller + recorder *MockDBClientMockRecorder +} + +// MockDBClientMockRecorder is the mock recorder for MockDBClient. +type MockDBClientMockRecorder struct { + mock *MockDBClient +} + +// NewMockDBClient creates a new mock instance. +func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { + mock := &MockDBClient{ctrl: ctrl} + mock.recorder = &MockDBClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { + return m.recorder +} + +// BeginTx mocks base method. +func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) + ret0, _ := ret[0].(db.TXClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTx indicates an expected call of BeginTx. +func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) +} + +// CloseStmt mocks base method. +func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) +} + +// Prepare mocks base method. +func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go new file mode 100644 index 00000000..43537816 --- /dev/null +++ b/pkg/sqlcache/store/store_test.go @@ -0,0 +1,649 @@ +/* +Copyright 2023 SUSE LLC + +Adapted from client-go, Copyright 2014 The Kubernetes Authors. +*/ + +package store + +// Mocks for this test are generated with the following command. +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + + "github.com/rancher/steve/pkg/sqlcache/db" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +const () +const TEST_DB_LOCATION = "./sqlstore.sqlite" + +func testStoreKeyFunc(obj interface{}) (string, error) { + return obj.(testStoreObject).Id, nil +} + +type testStoreObject struct { + Id string + Val string +} + +func TestAdd(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Add with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Add(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Commit().Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx db.TXClient) error { + count++ + return nil + }) + err := store.Add(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Add(testObject) + assert.NotNil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Add with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Add with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Add(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Update updates the given object in the accumulator associated with the given object's key +func TestUpdate(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Update with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Update(testObject) + assert.Nil(t, err) + // dbclient beginerr + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(nil) + + var count int + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + count++ + return nil + }) + err := store.Update(testObject) + assert.Nil(t, err) + assert.Equal(t, count, 1) + }, + }) + + tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + return fmt.Errorf("error") + }) + err := store.Update(testObject) + assert.NotNil(t, err) + }, + }) + + tests = append(tests, testCase{description: "Update with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + store := SetupStore(t, c, shouldEncrypt) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + tests = append(tests, testCase{description: "Update with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + + err := store.Update(testObject) + assert.NotNil(t, err) + }}) + + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Delete deletes the given object from the accumulator associated with the given object's key +func TestDelete(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + // Tests with shouldEncryptSet to false + tests = append(tests, testCase{description: "Delete with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(nil) + err := store.Delete(testObject) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Delete with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + // tx.EXPECT(). + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Commit().Return(fmt.Errorf("error")) + err := store.Delete(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// List returns a list of all the currently non-empty accumulators +func TestList(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + testObject := testStoreObject{Id: "something", Val: "a"} + + var tests []testCase + + tests = append(tests, testCase{description: "List with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + items := store.List() + assert.Len(t, items, 0) + }, + }) + tests = append(tests, testCase{description: "List with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + fakeItemsToReturn := []any{"something1", 2, false} + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(fakeItemsToReturn, nil) + items := store.List() + assert.Equal(t, fakeItemsToReturn, items) + }, + }) + tests = append(tests, testCase{description: "List with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listStmt).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + defer func() { + recover() + }() + _ = store.List() + assert.Fail(t, "Store list should panic when ReadObjects returns an error") + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// ListKeys returns a list of all the keys currently associated with non-empty accumulators +func TestListKeys(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + + tests = append(tests, testCase{description: "ListKeys with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{"a", "b", "c"}, nil) + keys := store.ListKeys() + assert.Len(t, keys, 3) + }, + }) + + tests = append(tests, testCase{description: "ListKeys with DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + keys := store.ListKeys() + assert.Len(t, keys, 0) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Get returns the accumulator associated with the given object's key +func TestGet(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Get with no DB Client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with no DB Client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.Get(testObject) + assert.Nil(t, err) + assert.Equal(t, item, nil) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "Get with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.Get(testObject) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// GetByKey returns the accumulator associated with the given key +func TestGetByKey(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{testObject}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, item, testObject) + assert.True(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return([]any{}, nil) + item, exists, err := store.GetByKey(testObject.Id) + assert.Nil(t, err) + assert.Equal(t, nil, item) + assert.False(t, exists) + }, + }) + tests = append(tests, testCase{description: "GetByKey with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().QueryForRows(context.TODO(), store.getStmt, testObject.Id).Return(r, nil) + c.EXPECT().ReadObjects(r, reflect.TypeOf(testObject), store.shouldEncrypt).Return(nil, fmt.Errorf("error")) + _, _, err := store.GetByKey(testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Replace will delete the contents of the store, using instead the +// given list. Store takes ownership of the list, you should not reference +// it after calling this function. +func TestReplace(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + testObject := testStoreObject{Id: "something", Val: "a"} + tests = append(tests, testCase{description: "Replace with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + txC.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + c, tx := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil) + tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{}, nil) + c.EXPECT().Upsert(tx, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) + tx.EXPECT().Commit() + err := store.Replace([]any{testObject}, testObject.Id) + assert.Nil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with no DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, tx := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil) + tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + c, tx := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(tx, nil) + tx.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + tests = append(tests, testCase{description: "Replace with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + c, txC := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + r := &sql.Rows{} + c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) + txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) + err := store.Replace([]any{testObject}, testObject.Id) + assert.NotNil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +// Resync is meaningless in the terms appearing here but has +// meaning in some implementations that have non-trivial +// additional behavior (e.g., DeltaFIFO). +func TestResync(t *testing.T) { + type testCase struct { + description string + test func(t *testing.T, shouldEncrypt bool) + } + + var tests []testCase + tests = append(tests, testCase{description: "Resync shouldn't call the client, panic, or do anything else", test: func(t *testing.T, shouldEncrypt bool) { + c, _ := SetupMockDB(t) + store := SetupStore(t, c, shouldEncrypt) + err := store.Resync() + assert.Nil(t, err) + }, + }) + t.Parallel() + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { test.test(t, false) }) + t.Run(fmt.Sprintf("%s with encryption", test.description), func(t *testing.T) { test.test(t, true) }) + } +} + +func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) { + dbC := NewMockDBClient(gomock.NewController(t)) // add functionality once store expectation are known + txC := NewMockTXClient(gomock.NewController(t)) + // stmt := NewMockStmt(gomock.NewController()) + txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil) + txC.EXPECT().Commit().Return(nil) + dbC.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + + // use stmt mock here + dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + + return dbC, txC +} +func SetupStore(t *testing.T, client *MockDBClient, shouldEncrypt bool) *Store { + store, err := NewStore(testStoreObject{}, testStoreKeyFunc, client, shouldEncrypt, "testStoreObject") + if err != nil { + t.Error(err) + } + return store +} diff --git a/pkg/sqlcache/store/tx_mocks_test.go b/pkg/sqlcache/store/tx_mocks_test.go new file mode 100644 index 00000000..0c05ab7f --- /dev/null +++ b/pkg/sqlcache/store/tx_mocks_test.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt +// + +// Package store is a generated GoMock package. +package store + +import ( + context "context" + sql "database/sql" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockStmt is a mock of Stmt interface. +type MockStmt struct { + ctrl *gomock.Controller + recorder *MockStmtMockRecorder +} + +// MockStmtMockRecorder is the mock recorder for MockStmt. +type MockStmtMockRecorder struct { + mock *MockStmt +} + +// NewMockStmt creates a new mock instance. +func NewMockStmt(ctrl *gomock.Controller) *MockStmt { + mock := &MockStmt{ctrl: ctrl} + mock.recorder = &MockStmtMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStmt) EXPECT() *MockStmtMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) +} + +// Query mocks base method. +func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) +} + +// QueryContext mocks base method. +func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryContext", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryContext indicates an expected call of QueryContext. +func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) +} diff --git a/pkg/stores/sqlpartition/listprocessor/processor.go b/pkg/stores/sqlpartition/listprocessor/processor.go index 0ea15788..d9f5e02e 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor.go +++ b/pkg/stores/sqlpartition/listprocessor/processor.go @@ -10,8 +10,8 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/schemas/validation" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/stores/sqlpartition/listprocessor/processor_test.go b/pkg/stores/sqlpartition/listprocessor/processor_test.go index 08b6ea6a..80cfb0b0 100644 --- a/pkg/stores/sqlpartition/listprocessor/processor_test.go +++ b/pkg/stores/sqlpartition/listprocessor/processor_test.go @@ -8,8 +8,8 @@ import ( "testing" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go index 693261a7..06598043 100644 --- a/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go +++ b/pkg/stores/sqlpartition/listprocessor/proxy_mocks_test.go @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/pkg/stores/sqlpartition/partition_mocks_test.go b/pkg/stores/sqlpartition/partition_mocks_test.go index ea72b8c1..687b9006 100644 --- a/pkg/stores/sqlpartition/partition_mocks_test.go +++ b/pkg/stores/sqlpartition/partition_mocks_test.go @@ -13,7 +13,7 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" watch "k8s.io/apimachinery/pkg/watch" diff --git a/pkg/stores/sqlpartition/partitioner.go b/pkg/stores/sqlpartition/partitioner.go index ffee38df..b3b74f9b 100644 --- a/pkg/stores/sqlpartition/partitioner.go +++ b/pkg/stores/sqlpartition/partitioner.go @@ -5,9 +5,9 @@ import ( "sort" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/attributes" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/kv" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlpartition/partitioner_test.go b/pkg/stores/sqlpartition/partitioner_test.go index caba1897..cdc93ff1 100644 --- a/pkg/stores/sqlpartition/partitioner_test.go +++ b/pkg/stores/sqlpartition/partitioner_test.go @@ -6,7 +6,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/stretchr/testify/assert" diff --git a/pkg/stores/sqlpartition/store.go b/pkg/stores/sqlpartition/store.go index 145d5e72..f4ebb325 100644 --- a/pkg/stores/sqlpartition/store.go +++ b/pkg/stores/sqlpartition/store.go @@ -7,14 +7,14 @@ import ( "context" "github.com/rancher/apiserver/pkg/types" - lassopartition "github.com/rancher/lasso/pkg/cache/sql/partition" "github.com/rancher/steve/pkg/accesscontrol" + cachepartition "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/stores/partition" ) // Partitioner is an interface for interacting with partitions. type Partitioner interface { - All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]lassopartition.Partition, error) + All(apiOp *types.APIRequest, schema *types.APISchema, verb, id string) ([]cachepartition.Partition, error) Store() UnstructuredStore } diff --git a/pkg/stores/sqlpartition/store_test.go b/pkg/stores/sqlpartition/store_test.go index d1f11be7..d98a8188 100644 --- a/pkg/stores/sqlpartition/store_test.go +++ b/pkg/stores/sqlpartition/store_test.go @@ -15,7 +15,7 @@ import ( "go.uber.org/mock/gomock" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/accesscontrol" "github.com/rancher/steve/pkg/stores/sqlproxy" "github.com/rancher/wrangler/v3/pkg/generic" diff --git a/pkg/stores/sqlproxy/proxy_mocks_test.go b/pkg/stores/sqlproxy/proxy_mocks_test.go index a5559d47..45ccd418 100644 --- a/pkg/stores/sqlproxy/proxy_mocks_test.go +++ b/pkg/stores/sqlproxy/proxy_mocks_test.go @@ -14,9 +14,9 @@ import ( reflect "reflect" types "github.com/rancher/apiserver/pkg/types" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - factory "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + factory "github.com/rancher/steve/pkg/sqlcache/informer/factory" + partition "github.com/rancher/steve/pkg/sqlcache/partition" summary "github.com/rancher/wrangler/v3/pkg/summary" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index e2bb17d2..0870f04f 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -31,9 +31,9 @@ import ( "github.com/rancher/apiserver/pkg/apierror" "github.com/rancher/apiserver/pkg/types" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/wrangler/v3/pkg/data" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/rancher/wrangler/v3/pkg/schemas/validation" @@ -333,7 +333,7 @@ func gvkKey(group, version, kind string) string { return group + "_" + version + "_" + kind } -// getFieldsFromSchema converts object field names from types.APISchema's format into lasso's +// getFieldsFromSchema converts object field names from types.APISchema's format into steve's // cache.sql.informer's slice format (e.g. "metadata.resourceVersion" is ["metadata", "resourceVersion"]) func getFieldsFromSchema(schema *types.APISchema) [][]string { var fields [][]string @@ -757,7 +757,7 @@ func (s *Store) ListByPartitions(apiOp *types.APIRequest, schema *types.APISchem list, total, continueToken, err := inf.ListByOptions(apiOp.Context(), opts, partitions, apiOp.Namespace) if err != nil { - if errors.Is(err, informer.InvalidColumnErr) { + if errors.Is(err, informer.ErrInvalidColumn) { return nil, 0, "", apierror.NewAPIError(validation.InvalidBodyContent, err.Error()) } return nil, 0, "", err diff --git a/pkg/stores/sqlproxy/proxy_store_test.go b/pkg/stores/sqlproxy/proxy_store_test.go index bdbbbaab..a621e79d 100644 --- a/pkg/stores/sqlproxy/proxy_store_test.go +++ b/pkg/stores/sqlproxy/proxy_store_test.go @@ -12,11 +12,11 @@ import ( "github.com/rancher/wrangler/v3/pkg/schemas/validation" apierrors "k8s.io/apimachinery/pkg/api/errors" - "github.com/rancher/lasso/pkg/cache/sql/informer" - "github.com/rancher/lasso/pkg/cache/sql/informer/factory" - "github.com/rancher/lasso/pkg/cache/sql/partition" "github.com/rancher/steve/pkg/attributes" "github.com/rancher/steve/pkg/resources/common" + "github.com/rancher/steve/pkg/sqlcache/informer" + "github.com/rancher/steve/pkg/sqlcache/informer/factory" + "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/stores/sqlpartition/listprocessor" "github.com/rancher/steve/pkg/stores/sqlproxy/tablelistconvert" "go.uber.org/mock/gomock" @@ -42,7 +42,7 @@ import ( ) //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./proxy_mocks_test.go github.com/rancher/steve/pkg/stores/sqlproxy Cache,ClientGetter,CacheFactory,SchemaColumnSetter,RelationshipNotifier,TransformBuilder -//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +//go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister //go:generate mockgen --build_flags=--mod=mod -package sqlproxy -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface var c *watch.FakeWatcher diff --git a/pkg/stores/sqlproxy/sql_informer_mocks_test.go b/pkg/stores/sqlproxy/sql_informer_mocks_test.go index e8f5358c..125f2192 100644 --- a/pkg/stores/sqlproxy/sql_informer_mocks_test.go +++ b/pkg/stores/sqlproxy/sql_informer_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/lasso/pkg/cache/sql/informer (interfaces: ByOptionsLister) +// Source: github.com/rancher/steve/pkg/sqlcache/informer (interfaces: ByOptionsLister) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/lasso/pkg/cache/sql/informer ByOptionsLister +// mockgen --build_flags=--mod=mod -package sqlproxy -destination ./sql_informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister // // Package sqlproxy is a generated GoMock package. @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - informer "github.com/rancher/lasso/pkg/cache/sql/informer" - partition "github.com/rancher/lasso/pkg/cache/sql/partition" + informer "github.com/rancher/steve/pkg/sqlcache/informer" + partition "github.com/rancher/steve/pkg/sqlcache/partition" gomock "go.uber.org/mock/gomock" unstructured "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) diff --git a/scripts/test.sh b/scripts/test.sh index f9143f72..14f7c8b7 100644 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,3 +1,13 @@ #!/bin/bash -go test ./... \ No newline at end of file +if ! command -v setup-envtest; then + echo "setup-envtest is required for tests, but was not installed" + echo "see the 'Running Tests' section of the readme for install instructions" + exit 127 +fi + +minor=$(go mod graph | grep ' k8s.io/client-go@' | head -n1 | cut -d@ -f2 | cut -d '.' -f 2) +version="1.$minor.x" + +export KUBEBUILDER_ASSETS=$(setup-envtest use -p path "$version") +go test ./...