diff --git a/agent/main.go b/agent/main.go index 1040bf888..b2598f2df 100644 --- a/agent/main.go +++ b/agent/main.go @@ -199,7 +199,7 @@ func runInHarReaderMode() { func enableExpFeatureIfNeeded() { if config.Config.OAS { oasGenerator := dependency.GetInstance(dependency.OasGeneratorDependency).(oas.OasGenerator) - oasGenerator.Start() + oasGenerator.Start(nil) } if config.Config.ServiceMap { serviceMapGenerator := dependency.GetInstance(dependency.ServiceMapGeneratorDependency).(servicemap.ServiceMap) @@ -371,7 +371,7 @@ func handleIncomingMessageAsTapper(socketConnection *websocket.Conn) { func initializeDependencies() { dependency.RegisterGenerator(dependency.ServiceMapGeneratorDependency, func() interface{} { return servicemap.GetDefaultServiceMapInstance() }) - dependency.RegisterGenerator(dependency.OasGeneratorDependency, func() interface{} { return oas.GetDefaultOasGeneratorInstance(nil) }) + dependency.RegisterGenerator(dependency.OasGeneratorDependency, func() interface{} { return oas.GetDefaultOasGeneratorInstance() }) dependency.RegisterGenerator(dependency.EntriesProvider, func() interface{} { return &entries.BasenineEntriesProvider{} }) dependency.RegisterGenerator(dependency.EntriesSocketStreamer, func() interface{} { return &api.BasenineEntryStreamer{} }) dependency.RegisterGenerator(dependency.EntryStreamerSocketConnector, func() interface{} { return &api.DefaultEntryStreamerSocketConnector{} }) diff --git a/agent/pkg/controllers/oas_controller_test.go b/agent/pkg/controllers/oas_controller_test.go index 381927b6e..419956951 100644 --- a/agent/pkg/controllers/oas_controller_test.go +++ b/agent/pkg/controllers/oas_controller_test.go @@ -58,12 +58,12 @@ func getRecorderAndContext() (*httptest.ResponseRecorder, *gin.Context) { receiveBuffer: bytes.NewBufferString("\n"), } dependency.RegisterGenerator(dependency.OasGeneratorDependency, func() interface{} { - return oas.GetDefaultOasGeneratorInstance(dummyConn) + return oas.GetDefaultOasGeneratorInstance() }) recorder := httptest.NewRecorder() c, _ := gin.CreateTestContext(recorder) - oas.GetDefaultOasGeneratorInstance(dummyConn).Start() - oas.GetDefaultOasGeneratorInstance(dummyConn).GetServiceSpecs().Store("some", oas.NewGen("some")) + oas.GetDefaultOasGeneratorInstance().Start(dummyConn) + oas.GetDefaultOasGeneratorInstance().GetServiceSpecs().Store("some", oas.NewGen("some")) return recorder, c } diff --git a/agent/pkg/oas/oas_generator.go b/agent/pkg/oas/oas_generator.go index 562eb8355..96db713cd 100644 --- a/agent/pkg/oas/oas_generator.go +++ b/agent/pkg/oas/oas_generator.go @@ -19,7 +19,7 @@ var ( ) type OasGenerator interface { - Start() + Start(conn *basenine.Connection) Stop() IsStarted() bool GetServiceSpecs() *sync.Map @@ -35,23 +35,41 @@ type defaultOasGenerator struct { entriesQuery string } -func GetDefaultOasGeneratorInstance(conn *basenine.Connection) *defaultOasGenerator { +func GetDefaultOasGeneratorInstance() *defaultOasGenerator { syncOnce.Do(func() { - instance = NewDefaultOasGenerator(conn) + instance = NewDefaultOasGenerator() logger.Log.Debug("OAS Generator Initialized") }) return instance } -func (g *defaultOasGenerator) Start() { +func (g *defaultOasGenerator) Start(conn *basenine.Connection) { if g.started { return } + + if g.dbConn == nil { + if conn == nil { + logger.Log.Infof("Creating new DB connection for OAS generator to address %s:%s", shared.BasenineHost, shared.BaseninePort) + newConn, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) + if err != nil { + logger.Log.Error("Error connecting to DB for OAS generator, err: %v", err) + return + } + + conn = newConn + } + + g.dbConn = conn + } + ctx, cancel := context.WithCancel(context.Background()) g.cancel = cancel g.ctx = ctx g.serviceSpecs = &sync.Map{} + g.started = true + go g.runGenerator() } @@ -59,8 +77,15 @@ func (g *defaultOasGenerator) Stop() { if !g.started { return } + + if g.dbConn != nil { + g.dbConn.Close() + g.dbConn = nil + } + g.cancel() g.reset() + g.started = false } @@ -69,7 +94,7 @@ func (g *defaultOasGenerator) IsStarted() bool { } func (g *defaultOasGenerator) runGenerator() { - // Make []byte channels to recieve the data and the meta + // Make []byte channels to receive the data and the meta dataChan := make(chan []byte) metaChan := make(chan []byte) @@ -80,6 +105,8 @@ func (g *defaultOasGenerator) runGenerator() { select { case <-g.ctx.Done(): logger.Log.Infof("OAS Generator was canceled") + close(dataChan) + close(metaChan) return case metaBytes, ok := <-metaChan: @@ -181,21 +208,12 @@ func (g *defaultOasGenerator) SetEntriesQuery(query string) bool { return changed } -func NewDefaultOasGenerator(conn *basenine.Connection) *defaultOasGenerator { - if conn == nil { - logger.Log.Infof("Creating new DB connection for OAS generator to address %s:%s", shared.BasenineHost, shared.BaseninePort) - newConn, err := basenine.NewConnection(shared.BasenineHost, shared.BaseninePort) - if err != nil { - panic(err) - } - conn = newConn - } - +func NewDefaultOasGenerator() *defaultOasGenerator { return &defaultOasGenerator{ started: false, ctx: nil, cancel: nil, serviceSpecs: nil, - dbConn: conn, + dbConn: nil, } } diff --git a/agent/pkg/oas/oas_generator_test.go b/agent/pkg/oas/oas_generator_test.go index b2c44b412..c7255bf9d 100644 --- a/agent/pkg/oas/oas_generator_test.go +++ b/agent/pkg/oas/oas_generator_test.go @@ -3,14 +3,11 @@ package oas import ( "encoding/json" "github.com/up9inc/mizu/agent/pkg/har" - "sync" "testing" ) func TestOASGen(t *testing.T) { gen := new(defaultOasGenerator) - gen.dbConn = GetFakeDBConn(`{"startedDateTime": "20000101","request": {"url": "https://host/path", "method": "GET"}, "response": {"status": 200}}`) - gen.serviceSpecs = &sync.Map{} e := new(har.Entry) err := json.Unmarshal([]byte(`{"startedDateTime": "20000101","request": {"url": "https://host/path", "method": "GET"}, "response": {"status": 200}}`), e) @@ -22,7 +19,9 @@ func TestOASGen(t *testing.T) { Destination: "some", Entry: *e, } - gen.Start() + + dummyConn := GetFakeDBConn(`{"startedDateTime": "20000101","request": {"url": "https://host/path", "method": "GET"}, "response": {"status": 200}}`) + gen.Start(dummyConn) gen.handleHARWithSource(ews) g, ok := gen.serviceSpecs.Load("some") if !ok { diff --git a/agent/pkg/oas/specgen_test.go b/agent/pkg/oas/specgen_test.go index da9a7f259..0eca3ab49 100644 --- a/agent/pkg/oas/specgen_test.go +++ b/agent/pkg/oas/specgen_test.go @@ -61,8 +61,7 @@ func TestEntries(t *testing.T) { t.FailNow() } - dummyConn := GetFakeDBConn("\n") - gen := NewDefaultOasGenerator(dummyConn) + gen := NewDefaultOasGenerator() gen.serviceSpecs = new(sync.Map) loadStartingOAS("test_artifacts/catalogue.json", "catalogue", gen.serviceSpecs) loadStartingOAS("test_artifacts/trcc.json", "trcc-api-service", gen.serviceSpecs) @@ -136,8 +135,7 @@ func TestEntries(t *testing.T) { } func TestFileSingle(t *testing.T) { - dummyConn := GetFakeDBConn("\n") - gen := NewDefaultOasGenerator(dummyConn) + gen := NewDefaultOasGenerator() gen.serviceSpecs = new(sync.Map) // loadStartingOAS() file := "test_artifacts/params.har" @@ -227,8 +225,7 @@ func loadStartingOAS(file string, label string, specs *sync.Map) { } func TestEntriesNegative(t *testing.T) { - dummyConn := GetFakeDBConn("\n") - gen := NewDefaultOasGenerator(dummyConn) + gen := NewDefaultOasGenerator() gen.serviceSpecs = new(sync.Map) files := []string{"invalid"} _, err := feedEntries(files, false, gen) @@ -239,8 +236,7 @@ func TestEntriesNegative(t *testing.T) { } func TestEntriesPositive(t *testing.T) { - dummyConn := GetFakeDBConn("\n") - gen := NewDefaultOasGenerator(dummyConn) + gen := NewDefaultOasGenerator() gen.serviceSpecs = new(sync.Map) files := []string{"test_artifacts/params.har"} _, err := feedEntries(files, false, gen)