Move loadExtensions into main.go and map extensions into extensionsMap

This commit is contained in:
M. Mert Yildiran 2021-08-20 18:30:21 +03:00
parent 461bcf9f24
commit f97e7c4793
No known key found for this signature in database
GPG Key ID: D42ADB236521BF7A
5 changed files with 72 additions and 65 deletions

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"mizuserver/pkg/api" "mizuserver/pkg/api"
"mizuserver/pkg/models" "mizuserver/pkg/models"
@ -12,6 +13,9 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"path"
"path/filepath"
"plugin"
"github.com/gin-contrib/static" "github.com/gin-contrib/static"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -28,8 +32,14 @@ var standaloneMode = flag.Bool("standalone", false, "Run in standalone tapper an
var apiServerAddress = flag.String("api-server-address", "", "Address of mizu API server") var apiServerAddress = flag.String("api-server-address", "", "Address of mizu API server")
var namespace = flag.String("namespace", "", "Resolve IPs if they belong to resources in this namespace (default is all)") var namespace = flag.String("namespace", "", "Resolve IPs if they belong to resources in this namespace (default is all)")
var extensions []*tapApi.Extension // global
var extensionsMap map[string]*tapApi.Extension // global
var allOutboundPorts []string // global
var allInboundPorts []string // global
func main() { func main() {
flag.Parse() flag.Parse()
loadExtensions()
hostMode := os.Getenv(shared.HostModeEnvVar) == "1" hostMode := os.Getenv(shared.HostModeEnvVar) == "1"
tapOpts := &tap.TapOpts{HostMode: hostMode} tapOpts := &tap.TapOpts{HostMode: hostMode}
@ -41,10 +51,10 @@ func main() {
api.StartResolving(*namespace) api.StartResolving(*namespace)
filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem) filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel) tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel, extensions)
// go filterHarItems(harOutputChannel, filteredOutputItemsChannel, getTrafficFilteringOptions()) // go filterHarItems(harOutputChannel, filteredOutputItemsChannel, getTrafficFilteringOptions())
go api.StartReadingEntries(filteredOutputItemsChannel, nil) go api.StartReadingEntries(filteredOutputItemsChannel, nil, extensionsMap)
// go api.StartReadingOutbound(outboundLinkOutputChannel) // go api.StartReadingOutbound(outboundLinkOutputChannel)
hostApi(nil) hostApi(nil)
@ -61,7 +71,7 @@ func main() {
// harOutputChannel, outboundLinkOutputChannel := tap.StartPassiveTapper(tapOpts) // harOutputChannel, outboundLinkOutputChannel := tap.StartPassiveTapper(tapOpts)
filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem) filteredOutputItemsChannel := make(chan *tapApi.OutputChannelItem)
tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel) tap.StartPassiveTapper(tapOpts, filteredOutputItemsChannel, extensions)
socketConnection, err := shared.ConnectToSocketServer(*apiServerAddress, shared.DEFAULT_SOCKET_RETRIES, shared.DEFAULT_SOCKET_RETRY_SLEEP_TIME, false) socketConnection, err := shared.ConnectToSocketServer(*apiServerAddress, shared.DEFAULT_SOCKET_RETRIES, shared.DEFAULT_SOCKET_RETRY_SLEEP_TIME, false)
if err != nil { if err != nil {
panic(fmt.Sprintf("Error connecting to socket server at %s %v", *apiServerAddress, err)) panic(fmt.Sprintf("Error connecting to socket server at %s %v", *apiServerAddress, err))
@ -77,7 +87,7 @@ func main() {
// filteredHarChannel := make(chan *tapApi.OutputChannelItem) // filteredHarChannel := make(chan *tapApi.OutputChannelItem)
// go filterHarItems(socketHarOutChannel, filteredHarChannel, getTrafficFilteringOptions()) // go filterHarItems(socketHarOutChannel, filteredHarChannel, getTrafficFilteringOptions())
go api.StartReadingEntries(socketHarOutChannel, nil) go api.StartReadingEntries(socketHarOutChannel, nil, extensionsMap)
hostApi(socketHarOutChannel) hostApi(socketHarOutChannel)
} }
@ -89,6 +99,55 @@ func main() {
rlog.Info("Exiting") rlog.Info("Exiting")
} }
func mergeUnique(slice []string, merge []string) []string {
for _, i := range merge {
add := true
for _, ele := range slice {
if ele == i {
add = false
}
}
if add {
slice = append(slice, i)
}
}
return slice
}
func loadExtensions() {
dir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
extensionsDir := path.Join(dir, "./extensions/")
files, err := ioutil.ReadDir(extensionsDir)
if err != nil {
log.Fatal(err)
}
extensions = make([]*tapApi.Extension, len(files))
extensionsMap = make(map[string]*tapApi.Extension)
for i, file := range files {
filename := file.Name()
log.Printf("Loading extension: %s\n", filename)
extension := &tapApi.Extension{
Path: path.Join(extensionsDir, filename),
}
plug, _ := plugin.Open(extension.Path)
extension.Plug = plug
symDissector, _ := plug.Lookup("Dissector")
var dissector tapApi.Dissector
dissector, _ = symDissector.(tapApi.Dissector)
dissector.Register(extension)
extension.Dissector = dissector
log.Printf("Extension Properties: %+v\n", extension)
extensions[i] = extension
extensionsMap[extension.Name] = extension
allOutboundPorts = mergeUnique(allOutboundPorts, extension.OutboundPorts)
allInboundPorts = mergeUnique(allInboundPorts, extension.InboundPorts)
}
log.Printf("allOutboundPorts: %v\n", allOutboundPorts)
log.Printf("allInboundPorts: %v\n", allInboundPorts)
}
func hostApi(socketHarOutputChannel chan<- *tapApi.OutputChannelItem) { func hostApi(socketHarOutputChannel chan<- *tapApi.OutputChannelItem) {
app := gin.Default() app := gin.Default()

View File

@ -51,11 +51,11 @@ func StartResolving(namespace string) {
holder.SetResolver(res) holder.SetResolver(res)
} }
func StartReadingEntries(harChannel <-chan *tapApi.OutputChannelItem, workingDir *string) { func StartReadingEntries(harChannel <-chan *tapApi.OutputChannelItem, workingDir *string, extensionsMap map[string]*tapApi.Extension) {
if workingDir != nil && *workingDir != "" { if workingDir != nil && *workingDir != "" {
startReadingFiles(*workingDir) startReadingFiles(*workingDir)
} else { } else {
startReadingChannel(harChannel) startReadingChannel(harChannel, extensionsMap)
} }
} }
@ -105,13 +105,15 @@ func startReadingFiles(workingDir string) {
} }
} }
func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem) { func startReadingChannel(outputItems <-chan *tapApi.OutputChannelItem, extensionsMap map[string]*tapApi.Extension) {
if outputItems == nil { if outputItems == nil {
panic("Channel of captured messages is nil") panic("Channel of captured messages is nil")
} }
for item := range outputItems { for item := range outputItems {
fmt.Printf("item: %+v\n", item) fmt.Printf("item: %+v\n", item)
extension := extensionsMap[item.Protocol]
fmt.Printf("extension: %+v\n", extension)
var req *http.Request var req *http.Request
marshedReq, _ := json.Marshal(item.Data.Request.Orig) marshedReq, _ := json.Marshal(item.Data.Request.Orig)
json.Unmarshal(marshedReq, &req) json.Unmarshal(marshedReq, &req)

View File

@ -44,7 +44,7 @@ type RequestResponsePair struct {
} }
type OutputChannelItem struct { type OutputChannelItem struct {
Type string Protocol string
Timestamp int64 Timestamp int64
ConnectionInfo *ConnectionInfo ConnectionInfo *ConnectionInfo
Data *RequestResponsePair Data *RequestResponsePair

View File

@ -86,7 +86,7 @@ func (matcher *requestResponseMatcher) registerResponse(ident string, response *
func (matcher *requestResponseMatcher) preparePair(requestHTTPMessage *api.GenericMessage, responseHTTPMessage *api.GenericMessage) *api.OutputChannelItem { func (matcher *requestResponseMatcher) preparePair(requestHTTPMessage *api.GenericMessage, responseHTTPMessage *api.GenericMessage) *api.OutputChannelItem {
return &api.OutputChannelItem{ return &api.OutputChannelItem{
Type: ExtensionName, Protocol: ExtensionName,
Timestamp: time.Now().UnixNano() / int64(time.Millisecond), Timestamp: time.Now().UnixNano() / int64(time.Millisecond),
ConnectionInfo: nil, ConnectionInfo: nil,
Data: &api.RequestResponsePair{ Data: &api.RequestResponsePair{

View File

@ -12,13 +12,9 @@ import (
"encoding/hex" "encoding/hex"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"os" "os"
"os/signal" "os/signal"
"path"
"path/filepath"
"plugin"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"strconv" "strconv"
@ -127,8 +123,6 @@ var nErrors uint
var ownIps []string // global var ownIps []string // global
var hostMode bool // global var hostMode bool // global
var extensions []*api.Extension // global var extensions []*api.Extension // global
var allOutboundPorts []string // global
var allInboundPorts []string // global
/* minOutputLevel: Error will be printed only if outputLevel is above this value /* minOutputLevel: Error will be printed only if outputLevel is above this value
* t: key for errorsMap (counting errors) * t: key for errorsMap (counting errors)
@ -192,8 +186,9 @@ func (c *Context) GetCaptureInfo() gopacket.CaptureInfo {
return c.CaptureInfo return c.CaptureInfo
} }
func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem) { func StartPassiveTapper(opts *TapOpts, outputItems chan *api.OutputChannelItem, extensionsRef []*api.Extension) {
hostMode = opts.HostMode hostMode = opts.HostMode
extensions = extensionsRef
if GetMemoryProfilingEnabled() { if GetMemoryProfilingEnabled() {
startMemoryProfiler() startMemoryProfiler()
@ -233,56 +228,7 @@ func startMemoryProfiler() {
}() }()
} }
func MergeUnique(slice []string, merge []string) []string {
for _, i := range merge {
add := true
for _, ele := range slice {
if ele == i {
add = false
}
}
if add {
slice = append(slice, i)
}
}
return slice
}
func loadExtensions() {
dir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
extensionsDir := path.Join(dir, "./extensions/")
files, err := ioutil.ReadDir(extensionsDir)
if err != nil {
log.Fatal(err)
}
extensions = make([]*api.Extension, len(files))
for i, file := range files {
filename := file.Name()
log.Printf("Loading extension: %s\n", filename)
extension := &api.Extension{
Path: path.Join(extensionsDir, filename),
}
plug, _ := plugin.Open(extension.Path)
extension.Plug = plug
symDissector, _ := plug.Lookup("Dissector")
var dissector api.Dissector
dissector, _ = symDissector.(api.Dissector)
dissector.Register(extension)
extension.Dissector = dissector
log.Printf("Extension Properties: %+v\n", extension)
extensions[i] = extension
allOutboundPorts = MergeUnique(allOutboundPorts, extension.OutboundPorts)
allInboundPorts = MergeUnique(allInboundPorts, extension.InboundPorts)
}
log.Printf("allOutboundPorts: %v\n", allOutboundPorts)
log.Printf("allInboundPorts: %v\n", allInboundPorts)
}
func startPassiveTapper(outputItems chan *api.OutputChannelItem) { func startPassiveTapper(outputItems chan *api.OutputChannelItem) {
loadExtensions()
log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile) log.SetFlags(log.LstdFlags | log.LUTC | log.Lshortfile)
defer util.Run()() defer util.Run()()