Move loadExtensions into main.go and map extensions into extensionsMap

This commit is contained in:
M. Mert Yildiran 2021-08-20 18:30:21 +03:00
parent 461bcf9f24
commit f97e7c4793
No known key found for this signature in database
GPG Key ID: D42ADB236521BF7A
5 changed files with 72 additions and 65 deletions

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"mizuserver/pkg/api"
"mizuserver/pkg/models"
@ -12,6 +13,9 @@ import (
"net/http"
"os"
"os/signal"
"path"
"path/filepath"
"plugin"
"github.com/gin-contrib/static"
"github.com/gin-gonic/gin"
@ -28,8 +32,14 @@ var standaloneMode = flag.Bool("standalone", false, "Run in standalone tapper an
var apiServerAddress = flag.String("api-server-address", "", "Address of mizu API server")
var namespace = flag.String("namespace", "", "Resolve IPs if they belong to resources in this namespace (default is all)")
var extensions []*tapApi.Extension // global
var extensionsMap map[string]*tapApi.Extension // global
var allOutboundPorts []string // global
var allInboundPorts []string // global
func main() {
flag.Parse()
loadExtensions()
hostMode := os.Getenv(shared.HostModeEnvVar) == "1"
tapOpts := &tap.TapOpts{HostMode: hostMode}
@ -41,10 +51,10 @@ func main() {
api.StartResolving(*namespace)
filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel, extensions)
// go filterHarItems(harOutputChannel, filteredOutputItemsChannel, getTrafficFilteringOptions())
go api.StartReadingEntries(filteredOutputItemsChannel, nil)
go api.StartReadingEntries(filteredOutputItemsChannel, nil, extensionsMap)
// go api.StartReadingOutbound(outboundLinkOutputChannel)
hostApi(nil)
@ -61,7 +71,7 @@ func main() {
// harOutputChannel, outboundLinkOutputChannel := tap.StartPassiveTapper(tapOpts)
filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel, extensions)
socketConnection, err := shared.ConnectToSocketServer(*apiServerAddress, shared.DEFAULT_SOCKET_RETRIES, shared.DEFAULT_SOCKET_RETRY_SLEEP_TIME, false)
if err != nil {
panic(fmt.Sprintf("Error connecting to socket server at %s %v", *apiServerAddress, err))
@ -77,7 +87,7 @@ func main() {
// filteredHarChannel := make(chan *tapApi.OutputChannelItem)
// go filterHarItems(socketHarOutChannel, filteredHarChannel, getTrafficFilteringOptions())
go api.StartReadingEntries(socketHarOutChannel, nil)
go api.StartReadingEntries(socketHarOutChannel, nil, extensionsMap)
hostApi(socketHarOutChannel)
}
@ -89,6 +99,55 @@ func main() {
rlog.Info("Exiting")
}
func mergeUnique(slice []string, merge []string) []string {
for _, i := range merge {
add := true
for _, ele := range slice {
if ele == i {
add = false
}
}
if add {
slice = append(slice, i)
}
}
return slice
}
func loadExtensions() {
dir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
extensionsDir := path.Join(dir, "./extensions/")
files, err := ioutil.ReadDir(extensionsDir)
if err != nil {
log.Fatal(err)
}
extensions = make([]*tapApi.Extension, len(files))
extensionsMap = make(map[string]*tapApi.Extension)
for i, file := range files {
filename := file.Name()
log.Printf("Loading extension: %s\n", filename)
extension := &tapApi.Extension{
Path: path.Join(extensionsDir, filename),
}
plug, _ := plugin.Open(extension.Path)
extension.Plug = plug
symDissector, _ := plug.Lookup("Dissector")
var dissector tapApi.Dissector
dissector, _ = symDissector.(tapApi.Dissector)
dissector.Register(extension)
extension.Dissector = dissector
log.Printf("Extension Properties: %+v\n", extension)
extensions[i] = extension
extensionsMap[extension.Name] = extension
allOutboundPorts = mergeUnique(allOutboundPorts, extension.OutboundPorts)
allInboundPorts = mergeUnique(allInboundPorts, extension.InboundPorts)
}
log.Printf("allOutboundPorts: %v\n", allOutboundPorts)
log.Printf("allInboundPorts: %v\n", allInboundPorts)
}
func hostApi(socketHarOutputChannel chan<- *tapApi.OutputChannelItem) {
app := gin.Default()

View File

@ -51,11 +51,11 @@ func StartResolving(namespace string) {
holder.SetResolver(res)
}
func StartReadingEntries(harChannel <-chan *tapApi.OutputChannelItem, workingDir *string) {
func StartReadingEntries(harChannel <-chan *tapApi.OutputChannelItem, workingDir *string, extensionsMap map[string]*tapApi.Extension) {
if workingDir != nil && *workingDir != "" {
startReadingFiles(*workingDir)
} else {
startReadingChannel(harChannel)
startReadingChannel(harChannel, extensionsMap)
}
}
@ -105,13 +105,15 @@ func startReadingFiles(workingDir string) {
}
}
func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem) {
func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem, extensionsMap map[string]*tapApi.Extension) {
if outputItems == nil {
panic("Channel of captured messages is nil")
}
for item := range outputItems {
fmt.Printf("item: %+v\n", item)
extension := extensionsMap[item.Protocol]
fmt.Printf("extension: %+v\n", extension)
var req *http.Request
marshedReq, _ := json.Marshal(item.Data.Request.Orig)
json.Unmarshal(marshedReq, &req)

View File

@ -44,7 +44,7 @@ type RequestResponsePair struct {
}
type OutputChannelItem struct {
Type string
Protocol string
Timestamp int64
ConnectionInfo *ConnectionInfo
Data *RequestResponsePair

View File

@ -86,7 +86,7 @@ func (matcher *requestResponseMatcher) registerResponse(ident string, response *
func (matcher *requestResponseMatcher) preparePair(requestHTTPMessage *api.GenericMessage, responseHTTPMessage *api.GenericMessage) *api.OutputChannelItem {
return &api.OutputChannelItem{
Type: ExtensionName,
Protocol: ExtensionName,
Timestamp: time.Now().UnixNano() / int64(time.Millisecond),
ConnectionInfo: nil,
Data: &api.RequestResponsePair{

View File

@ -12,13 +12,9 @@ import (
"encoding/hex"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"os/signal"
"path"
"path/filepath"
"plugin"
"runtime"
"runtime/pprof"
"strconv"
@ -127,8 +123,6 @@ var nErrors uint
var ownIps []string // global
var hostMode bool // global
var extensions []*api.Extension // global
var allOutboundPorts []string // global
var allInboundPorts []string // global
/* minOutputLevel: Error will be printed only if outputLevel is above this value
* t: key for errorsMap (counting errors)
@ -192,8 +186,9 @@ func (c *Context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo
}
func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) {
func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, extensionsRef []*api.Extension) {
hostMode = opts.HostMode
extensions = extensionsRef
if GetMemoryProfilingEnabled() {
startMemoryProfiler()
@ -233,56 +228,7 @@ func startMemoryProfiler() {
}()
}
func MergeUnique(slice []string, merge []string) []string {
for _, i := range merge {
add := true
for _, ele := range slice {
if ele == i {
add = false
}
}
if add {
slice = append(slice, i)
}
}
return slice
}
func loadExtensions() {
dir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
extensionsDir := path.Join(dir, "./extensions/")
files, err := ioutil.ReadDir(extensionsDir)
if err != nil {
log.Fatal(err)
}
extensions = make([]*api.Extension, len(files))
for i, file := range files {
filename := file.Name()
log.Printf("Loading extension: %s\n", filename)
extension := &api.Extension{
Path: path.Join(extensionsDir, filename),
}
plug, _ := plugin.Open(extension.Path)
extension.Plug = plug
symDissector, _ := plug.Lookup("Dissector")
var dissector api.Dissector
dissector, _ = symDissector.(api.Dissector)
dissector.Register(extension)
extension.Dissector = dissector
log.Printf("Extension Properties: %+v\n", extension)
extensions[i] = extension
allOutboundPorts = MergeUnique(allOutboundPorts, extension.OutboundPorts)
allInboundPorts = MergeUnique(allInboundPorts, extension.InboundPorts)
}
log.Printf("allOutboundPorts: %v\n", allOutboundPorts)
log.Printf("allInboundPorts: %v\n", allInboundPorts)
}
func startPassiveTapper(outputItems chan *api.OutputChannelItem) {
loadExtensions()
log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile)
defer util.Run()()