diff --git a/pkg/kubectl/cmd/cp.go b/pkg/kubectl/cmd/cp.go index 83ab0e26727..aa4cfe881ab 100644 --- a/pkg/kubectl/cmd/cp.go +++ b/pkg/kubectl/cmd/cp.go @@ -23,6 +23,7 @@ import ( "io/ioutil" "os" "path" + "path/filepath" "strings" "k8s.io/kubernetes/pkg/kubectl/cmd/templates" @@ -236,6 +237,8 @@ func recursiveTar(base, file string, tw *tar.Writer) error { } func untarAll(reader io.Reader, destFile, prefix string) error { + entrySeq := -1 + // TODO: use compression here? tarReader := tar.NewReader(reader) for { @@ -246,25 +249,38 @@ func untarAll(reader io.Reader, destFile, prefix string) error { } break } + entrySeq++ outFileName := path.Join(destFile, header.Name[len(prefix):]) baseName := path.Dir(outFileName) if err := os.MkdirAll(baseName, 0755); err != nil { return err } if header.FileInfo().IsDir() { - os.MkdirAll(outFileName, 0755) + + if err := os.MkdirAll(outFileName, 0755); err != nil { + return err + } continue } + + // handle coping remote file into local directory + if entrySeq == 0 && !header.FileInfo().IsDir() { + exists, err := dirExists(outFileName) + if err != nil { + return err + } + if exists { + outFileName = filepath.Join(outFileName, path.Base(header.Name)) + } + } outFile, err := os.Create(outFileName) if err != nil { return err } + defer outFile.Close() if _, err := io.Copy(outFile, tarReader); err != nil { return err } - if err := outFile.Close(); err != nil { - return err - } } return nil } @@ -312,3 +328,15 @@ func execute(f cmdutil.Factory, cmd *cobra.Command, options *ExecOptions) error } return nil } + +// dirExists checks if a path exists and is a directory. +func dirExists(path string) (bool, error) { + fi, err := os.Stat(path) + if err == nil && fi.IsDir() { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} diff --git a/pkg/kubectl/cmd/cp_test.go b/pkg/kubectl/cmd/cp_test.go index 7f2cdd3ea56..d3d86ef8fc9 100644 --- a/pkg/kubectl/cmd/cp_test.go +++ b/pkg/kubectl/cmd/cp_test.go @@ -21,7 +21,9 @@ import ( "io" "io/ioutil" "os" + "os/exec" "path" + "path/filepath" "testing" ) @@ -179,3 +181,114 @@ func TestTarUntar(t *testing.T) { } } } + +// TestCopyToLocalFileOrDir tests untarAll in two cases : +// 1: copy pod file to local file +// 2: copy pod file into local directory +func TestCopyToLocalFileOrDir(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "input") + dir2, err2 := ioutil.TempDir(os.TempDir(), "output") + if err != nil || err2 != nil { + t.Errorf("unexpected error: %v | %v", err, err2) + t.FailNow() + } + defer func() { + if err := os.RemoveAll(dir); err != nil { + t.Errorf("Unexpected error cleaning up: %v", err) + } + if err := os.RemoveAll(dir2); err != nil { + t.Errorf("Unexpected error cleaning up: %v", err) + } + }() + + files := []struct { + name string + data string + dest string + destDirExists bool + }{ + { + name: "foo", + data: "foobarbaz", + dest: "path/to/dest", + destDirExists: false, + }, + { + name: "dir/blah", + data: "bazblahfoo", + dest: "dest/file/path", + destDirExists: true, + }, + } + + for _, file := range files { + func() { + // setup + srcFilePath := filepath.Join(dir, file.name) + destPath := filepath.Join(dir2, file.dest) + if err := os.MkdirAll(filepath.Dir(srcFilePath), 0755); err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + srcFile, err := os.Create(srcFilePath) + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + defer srcFile.Close() + + if _, err := io.Copy(srcFile, bytes.NewBuffer([]byte(file.data))); err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + if file.destDirExists { + if err := os.MkdirAll(destPath, 0755); err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + } + + // start tests + srcTarFilePath := filepath.Join(dir, file.name+".tar") + // here use tar command to create tar file instead of calling makeTar func + // because makeTar func can not generate correct header name + err = exec.Command("tar", "cf", srcTarFilePath, srcFilePath).Run() + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + srcTarFile, err := os.Open(srcTarFilePath) + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + defer srcTarFile.Close() + + if err := untarAll(srcTarFile, destPath, getPrefix(srcFilePath)); err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + + actualDestFilePath := destPath + if file.destDirExists { + actualDestFilePath = filepath.Join(destPath, filepath.Base(srcFilePath)) + } + _, err = os.Stat(actualDestFilePath) + if err != nil && os.IsNotExist(err) { + t.Errorf("expecting %s exists, but actually it's missing", actualDestFilePath) + } + destFile, err := os.Open(actualDestFilePath) + if err != nil { + t.Errorf("unexpected error: %v", err) + t.FailNow() + } + defer destFile.Close() + buff := &bytes.Buffer{} + io.Copy(buff, destFile) + if file.data != string(buff.Bytes()) { + t.Errorf("expected: %s, actual: %s", file.data, string(buff.Bytes())) + } + }() + } + +}