diff --git a/cmd/tap.go b/cmd/tap.go index 5191deec0..60ed0c3ff 100644 --- a/cmd/tap.go +++ b/cmd/tap.go @@ -56,7 +56,7 @@ func init() { tapCmd.Flags().String(configStructs.StorageLimitLabel, defaultTapConfig.StorageLimit, "Override the default storage limit (per node)") tapCmd.Flags().String(configStructs.StorageClassLabel, defaultTapConfig.StorageClass, "Override the default storage class of the PersistentVolumeClaim (per node)") tapCmd.Flags().Bool(configStructs.DryRunLabel, defaultTapConfig.DryRun, "Preview of all pods matching the regex, without tapping them") - tapCmd.Flags().StringP(configStructs.PcapLabel, "p", defaultTapConfig.Pcap, fmt.Sprintf("Capture from a PCAP snapshot of %s (.tar.gz) using your Docker Daemon instead of Kubernetes", misc.Software)) + tapCmd.Flags().StringP(configStructs.PcapLabel, "p", defaultTapConfig.Pcap, fmt.Sprintf("Capture from a PCAP snapshot of %s (.tar.gz) using your Docker Daemon instead of Kubernetes. TAR path, S3 URL (object or bucket)", misc.Software)) tapCmd.Flags().Bool(configStructs.ServiceMeshLabel, defaultTapConfig.ServiceMesh, "Capture the encrypted traffic if the cluster is configured with a service mesh and with mTLS") tapCmd.Flags().Bool(configStructs.TlsLabel, defaultTapConfig.Tls, "Capture the traffic that's encrypted with OpenSSL or Go crypto/tls libraries") tapCmd.Flags().Bool(configStructs.IgnoreTaintedLabel, defaultTapConfig.IgnoreTainted, "Ignore tainted pods while running Worker DaemonSet") diff --git a/cmd/tapPcapRunner.go b/cmd/tapPcapRunner.go index 86db43a7d..bcbd8937f 100644 --- a/cmd/tapPcapRunner.go +++ b/cmd/tapPcapRunner.go @@ -1,13 +1,16 @@ package cmd import ( + "archive/tar" "bufio" + "compress/gzip" "context" "encoding/json" "fmt" "io" "net/url" "os" + "path/filepath" "strings" "github.com/aws/aws-sdk-go-v2/aws" @@ -284,35 +287,145 @@ func downloadTarFromS3(s3Url string) (tarPath string, err error) { return } + bucket := u.Host + key := u.Path[1:] + var cfg aws.Config cfg, err = awsConfig.LoadDefaultConfig(context.TODO()) if err != nil { return } - var file *os.File - file, err = os.CreateTemp(os.TempDir(), "kubeshark_s3_*.tar.gz") - if err != nil { - return - } - - log.Info().Str("bucket", u.Host).Str("key", u.Path[1:]).Msg("Downloading from S3") - client := s3.NewFromConfig(cfg) - downloader := manager.NewDownloader(client) - _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ - Bucket: aws.String(u.Host), - Key: aws.String(u.Path[1:]), + + var listObjectsOutput *s3.ListObjectsV2Output + listObjectsOutput, err = client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), }) if err != nil { return } - tarPath = file.Name() + if key == "" { + var tempDirPath string + tempDirPath, err = os.MkdirTemp(os.TempDir(), "kubeshark_*") + if err != nil { + return + } + + for _, object := range listObjectsOutput.Contents { + key = *object.Key + fullPath := filepath.Join(tempDirPath, key) + err = os.MkdirAll(filepath.Dir(fullPath), os.ModePerm) + if err != nil { + return + } + + var file *os.File + file, err = os.Create(fullPath) + if err != nil { + return + } + + log.Info().Str("bucket", bucket).Str("key", key).Msg("Downloading from S3") + + downloader := manager.NewDownloader(client) + _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return + } + } + + tarPath, err = tarDirectory(tempDirPath) + } else { + var file *os.File + file, err = os.CreateTemp(os.TempDir(), filepath.Base(key)) + if err != nil { + return + } + + log.Info().Str("bucket", bucket).Str("key", key).Msg("Downloading from S3") + + downloader := manager.NewDownloader(client) + _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + return + } + + tarPath = file.Name() + } return } +func tarDirectory(dirPath string) (string, error) { + tarPath := fmt.Sprintf("%s.tar.gz", dirPath) + + var file *os.File + file, err := os.Create(tarPath) + if err != nil { + return "", err + } + defer file.Close() + + gzipWriter := gzip.NewWriter(file) + defer gzipWriter.Close() + + tarWriter := tar.NewWriter(gzipWriter) + defer tarWriter.Close() + + walker := func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + stat, err := file.Stat() + if err != nil { + return err + } + + header := &tar.Header{ + Name: path[len(dirPath)+1:], + Size: stat.Size(), + Mode: int64(stat.Mode()), + ModTime: stat.ModTime(), + } + + err = tarWriter.WriteHeader(header) + if err != nil { + return err + } + + _, err = io.Copy(tarWriter, file) + if err != nil { + return err + } + + return nil + } + + err = filepath.Walk(dirPath, walker) + if err != nil { + return "", err + } + + return tarPath, nil +} + func pcap(tarPath string) error { if strings.HasPrefix(tarPath, "s3://") { var err error