diff --git a/cmd/cloudinitsave/cloudinitsave.go b/cmd/cloudinitsave/cloudinitsave.go index e2568083..aecb2913 100644 --- a/cmd/cloudinitsave/cloudinitsave.go +++ b/cmd/cloudinitsave/cloudinitsave.go @@ -38,6 +38,7 @@ import ( "github.com/rancher/os/config/cloudinit/datasource/metadata/gce" "github.com/rancher/os/config/cloudinit/datasource/metadata/packet" "github.com/rancher/os/config/cloudinit/datasource/proccmdline" + "github.com/rancher/os/config/cloudinit/datasource/tftp" "github.com/rancher/os/config/cloudinit/datasource/url" "github.com/rancher/os/config/cloudinit/datasource/vmware" "github.com/rancher/os/config/cloudinit/pkg" @@ -237,6 +238,8 @@ func getDatasources(datasources []string) []datasource.Datasource { if root != "" { dss = append(dss, file.NewDatasource(root)) } + case "tftp": + dss = append(dss, tftp.NewDatasource(root)) case "url": if root != "" { dss = append(dss, url.NewDatasource(root)) diff --git a/config/cloudinit/datasource/tftp/tftp.go b/config/cloudinit/datasource/tftp/tftp.go new file mode 100644 index 00000000..d7e3e969 --- /dev/null +++ b/config/cloudinit/datasource/tftp/tftp.go @@ -0,0 +1,83 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "regexp" + "strings" + + "github.com/rancher/os/config/cloudinit/datasource" + + "github.com/pin/tftp" +) + +type Client interface { + Receive(filename string, mode string) (io.WriterTo, error) +} + +type RemoteFile struct { + host string + path string + client Client + stream io.WriterTo + lastError error +} + +func NewDatasource(hostAndPath string) *RemoteFile { + parts := strings.SplitN(hostAndPath, "/", 2) + + if len(parts) < 2 { + return &RemoteFile{hostAndPath, "", nil, nil, nil} + } + + host := parts[0] + if match, _ := regexp.MatchString(":[0-9]{2,5}$", host); !match { + // No port, using default port 69 + host += ":69" + } + + path := parts[1] + if client, lastError := tftp.NewClient(host); lastError == nil { + return &RemoteFile{host, path, client, nil, nil} + } + + return &RemoteFile{host, path, nil, nil, nil} +} + +func (f *RemoteFile) IsAvailable() bool { + f.stream, f.lastError = f.client.Receive(f.path, "octet") + return f.lastError == nil +} + +func (f *RemoteFile) Finish() error { + return nil +} + +func (f *RemoteFile) String() string { + return fmt.Sprintf("%s: %s%s (lastError: %v)", f.Type(), f.host, f.path, f.lastError) +} + +func (f *RemoteFile) AvailabilityChanges() bool { + return false +} + +func (f *RemoteFile) ConfigRoot() string { + return "" +} + +func (f *RemoteFile) FetchMetadata() (datasource.Metadata, error) { + return datasource.Metadata{}, nil +} + +func (f *RemoteFile) FetchUserdata() ([]byte, error) { + var b bytes.Buffer + + _, err := f.stream.WriteTo(&b) + + return b.Bytes(), err +} + +func (f *RemoteFile) Type() string { + return "tftp" +} diff --git a/config/cloudinit/datasource/tftp/tftp_test.go b/config/cloudinit/datasource/tftp/tftp_test.go new file mode 100644 index 00000000..c9722327 --- /dev/null +++ b/config/cloudinit/datasource/tftp/tftp_test.go @@ -0,0 +1,92 @@ +package tftp + +import ( + "fmt" + "io" + "reflect" + "testing" +) + +type mockClient struct { +} + +type mockReceiver struct { +} + +func (r mockReceiver) WriteTo(w io.Writer) (n int64, err error) { + b := []byte("cloud-config file") + w.Write(b) + return int64(len(b)), nil +} + +func (c mockClient) Receive(filename string, mode string) (io.WriterTo, error) { + if filename == "does-not-exist" { + return &mockReceiver{}, fmt.Errorf("does not exist") + } + return &mockReceiver{}, nil +} + +var _ Client = (*mockClient)(nil) + +func TestNewDatasource(t *testing.T) { + for _, tt := range []struct { + root string + expectHost string + expectPath string + }{ + { + root: "127.0.0.1/test/file.yaml", + expectHost: "127.0.0.1:69", + expectPath: "test/file.yaml", + }, + { + root: "127.0.0.1/test/file.yaml", + expectHost: "127.0.0.1:69", + expectPath: "test/file.yaml", + }, + } { + ds := NewDatasource(tt.root) + if ds.host != tt.expectHost || ds.path != tt.expectPath { + t.Fatalf("bad host or path (%q): want host=%s, got %s, path=%s, got %s", tt.root, tt.expectHost, ds.host, tt.expectPath, ds.path) + } + } +} + +func TestIsAvailable(t *testing.T) { + for _, tt := range []struct { + remoteFile *RemoteFile + expect bool + }{ + { + remoteFile: &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil}, + expect: true, + }, + { + remoteFile: &RemoteFile{"1.2.3.4", "does-not-exist", &mockClient{}, nil, nil}, + expect: false, + }, + } { + if tt.remoteFile.IsAvailable() != tt.expect { + t.Fatalf("expected remote file %s to be %v", tt.remoteFile.path, tt.expect) + } + } +} + +func TestFetchUserdata(t *testing.T) { + rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, &mockReceiver{}, nil} + b, _ := rf.FetchUserdata() + + expect := []byte("cloud-config file") + + if len(b) != len(expect) || !reflect.DeepEqual(b, expect) { + t.Fatalf("expected length of buffer to be %d was %d. Expected %s, got %s", len(expect), len(b), string(expect), string(b)) + } +} + +func TestType(t *testing.T) { + rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil} + + if rf.Type() != "tftp" { + t.Fatalf("expected remote file Type() to return %s got %s", "tftp", rf.Type()) + } +} diff --git a/trash.conf b/trash.conf index 591d7449..bbf3b98c 100644 --- a/trash.conf +++ b/trash.conf @@ -57,3 +57,4 @@ golang.org/x/sys eb2c74142fd19a79b3f237334c7384d5167b1b46 https://github.com/gol google.golang.org/grpc ab0be5212fb225475f2087566eded7da5d727960 https://github.com/grpc/grpc-go.git gopkg.in/fsnotify.v1 v1.2.0 github.com/fatih/structs dc3312cb1a4513a366c4c9e622ad55c32df12ed3 +github.com/pin/tftp v2.1.0 diff --git a/vendor/github.com/pin/tftp/.gitignore b/vendor/github.com/pin/tftp/.gitignore new file mode 100644 index 00000000..daf913b1 --- /dev/null +++ b/vendor/github.com/pin/tftp/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/pin/tftp/.travis.yml b/vendor/github.com/pin/tftp/.travis.yml new file mode 100644 index 00000000..edbec9a6 --- /dev/null +++ b/vendor/github.com/pin/tftp/.travis.yml @@ -0,0 +1,8 @@ +language: go + +os: + - linux + - osx + +before_install: + - ulimit -n 4096 diff --git a/vendor/github.com/pin/tftp/CONTRIBUTORS b/vendor/github.com/pin/tftp/CONTRIBUTORS new file mode 100644 index 00000000..c8c331de --- /dev/null +++ b/vendor/github.com/pin/tftp/CONTRIBUTORS @@ -0,0 +1,4 @@ +Dmitri Popov +Mojo Talantikite +Giovanni Bajo +Andrew Danforth diff --git a/vendor/github.com/pin/tftp/LICENSE b/vendor/github.com/pin/tftp/LICENSE new file mode 100644 index 00000000..dada3d09 --- /dev/null +++ b/vendor/github.com/pin/tftp/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) +Copyright (c) 2016 Dmitri Popov + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/pin/tftp/README.md b/vendor/github.com/pin/tftp/README.md new file mode 100644 index 00000000..e360cc88 --- /dev/null +++ b/vendor/github.com/pin/tftp/README.md @@ -0,0 +1,171 @@ +TFTP server and client library for Golang +========================================= + +[![GoDoc](https://godoc.org/github.com/pin/tftp?status.svg)](https://godoc.org/github.com/pin/tftp) +[![Build Status](https://travis-ci.org/pin/tftp.svg?branch=master)](https://travis-ci.org/pin/tftp) + +Implements: + * [RFC 1350](https://tools.ietf.org/html/rfc1350) - The TFTP Protocol (Revision 2) + * [RFC 2347](https://tools.ietf.org/html/rfc2347) - TFTP Option Extension + * [RFC 2348](https://tools.ietf.org/html/rfc2348) - TFTP Blocksize Option + +Partially implements (tsize server side only): + * [RFC 2349](https://tools.ietf.org/html/rfc2349) - TFTP Timeout Interval and Transfer Size Options + +Set of features is sufficient for PXE boot support. + +``` go +import "github.com/pin/tftp" +``` + +The package is cohesive to Golang `io`. Particularly it implements +`io.ReaderFrom` and `io.WriterTo` interfaces. That allows efficient data +transmission without unnecessary memory copying and allocations. + + +TFTP Server +----------- + +```go + +// readHandler is called when client starts file download from server +func readHandler(filename string, rf io.ReaderFrom) error { + file, err := os.Open(filename) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + return err + } + n, err := rf.ReadFrom(file) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + return err + } + fmt.Printf("%d bytes sent\n", n) + return nil +} + +// writeHandler is called when client starts file upload to server +func writeHandler(filename string, wt io.WriterTo) error { + file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + return err + } + n, err := wt.WriteTo(file) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + return err + } + fmt.Printf("%d bytes received\n", n) + return nil +} + +func main() { + // use nil in place of handler to disable read or write operations + s := tftp.NewServer(readHandler, writeHandler) + s.SetTimeout(5 * time.Second) // optional + err := s.ListenAndServe(":69") // blocks until s.Shutdown() is called + if err != nil { + fmt.Fprintf(os.Stdout, "server: %v\n", err) + os.Exit(1) + } +} +``` + +TFTP Client +----------- +Upload file to server: + +```go +c, err := tftp.NewClient("172.16.4.21:69") +file, err := os.Open(path) +c.SetTimeout(5 * time.Second) // optional +rf, err := c.Send("foobar.txt", "octet") +n, err := rf.ReadFrom(file) +fmt.Printf("%d bytes sent\n", n) +``` + +Download file from server: + +```go +c, err := tftp.NewClient("172.16.4.21:69") +wt, err := c.Receive("foobar.txt", "octet") +file, err := os.Create(path) +// Optionally obtain transfer size before actual data. +if n, ok := wt.(IncomingTransfer).Size(); ok { + fmt.Printf("Transfer size: %d\n", n) +} +n, err := wt.WriteTo(file) +fmt.Printf("%d bytes received\n", n) +``` + +Note: please handle errors better :) + +TSize option +------------ + +PXE boot ROM often expects tsize option support from a server: client +(e.g. computer that boots over the network) wants to know size of a +download before the actual data comes. Server has to obtain stream +size and send it to a client. + +Often it will happen automatically because TFTP library tries to check +if `io.Reader` provided to `ReadFrom` method also satisfies +`io.Seeker` interface (`os.File` for instance) and uses `Seek` to +determine file size. + +In case `io.Reader` you provide to `ReadFrom` in read handler does not +satisfy `io.Seeker` interface or you do not want TFTP library to call +`Seek` on your reader but still want to respond with tsize option +during outgoing request you can use an `OutgoingTransfer` interface: + +```go + +func readHandler(filename string, rf io.ReaderFrom) error { + ... + // Set transfer size before calling ReadFrom. + rf.(tftp.OutgoingTransfer).SetSize(myFileSize) + ... + // ReadFrom ... + +``` + +Similarly, it is possible to obtain size of a file that is about to be +received using `IncomingTransfer` interface (see `Size` method). + +Remote Address +-------------- + +The `OutgoingTransfer` and `IncomingTransfer` interfaces also provide the +`RemoteAddr` method which returns the peer IP address and port as a +`net.UDPAddr`. This can be used for detailed logging in a server handler. + +```go + +func readHandler(filename string, rf io.ReaderFrom) error { + ... + raddr := rf.(tftp.OutgoingTransfer).RemoteAddr() + log.Println("RRQ from", raddr.String()) + ... + // ReadFrom ... +``` + +Backoff +------- + +The default backoff before retransmitting an unacknowledged packet is a +random duration between 0 and 1 second. This behavior can be overridden +in clients and servers by providing a custom backoff calculation function. + +```go + s := tftp.NewServer(readHandler, writeHandler) + s.SetBackoff(func (attempts int) time.Duration { + return time.Duration(attempts) * time.Second + }) +``` + +or, for no backoff + +```go + s.SetBackoff(func (int) time.Duration { return 0 }) +``` diff --git a/vendor/github.com/pin/tftp/backoff.go b/vendor/github.com/pin/tftp/backoff.go new file mode 100644 index 00000000..5c5326c2 --- /dev/null +++ b/vendor/github.com/pin/tftp/backoff.go @@ -0,0 +1,35 @@ +package tftp + +import ( + "math/rand" + "time" +) + +const ( + defaultTimeout = 5 * time.Second + defaultRetries = 5 +) + +type backoffFunc func(int) time.Duration + +type backoff struct { + attempt int + handler backoffFunc +} + +func (b *backoff) reset() { + b.attempt = 0 +} + +func (b *backoff) count() int { + return b.attempt +} + +func (b *backoff) backoff() { + if b.handler == nil { + time.Sleep(time.Duration(rand.Int63n(int64(time.Second)))) + } else { + time.Sleep(b.handler(b.attempt)) + } + b.attempt++ +} diff --git a/vendor/github.com/pin/tftp/client.go b/vendor/github.com/pin/tftp/client.go new file mode 100644 index 00000000..9f1802de --- /dev/null +++ b/vendor/github.com/pin/tftp/client.go @@ -0,0 +1,125 @@ +package tftp + +import ( + "fmt" + "io" + "net" + "strconv" + "time" +) + +// NewClient creates TFTP client for server on address provided. +func NewClient(addr string) (*Client, error) { + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, fmt.Errorf("resolving address %s: %v", addr, err) + } + return &Client{ + addr: a, + timeout: defaultTimeout, + retries: defaultRetries, + }, nil +} + +// SetTimeout sets maximum time client waits for single network round-trip to succeed. +// Default is 5 seconds. +func (c *Client) SetTimeout(t time.Duration) { + if t <= 0 { + c.timeout = defaultTimeout + } + c.timeout = t +} + +// SetRetries sets maximum number of attempts client made to transmit a packet. +// Default is 5 attempts. +func (c *Client) SetRetries(count int) { + if count < 1 { + c.retries = defaultRetries + } + c.retries = count +} + +// SetBackoff sets a user provided function that is called to provide a +// backoff duration prior to retransmitting an unacknowledged packet. +func (c *Client) SetBackoff(h backoffFunc) { + c.backoff = h +} + +type Client struct { + addr *net.UDPAddr + timeout time.Duration + retries int + backoff backoffFunc + blksize int + tsize bool +} + +// Send starts outgoing file transmission. It returns io.ReaderFrom or error. +func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) { + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return nil, err + } + s := &sender{ + send: make([]byte, datagramLength), + receive: make([]byte, datagramLength), + conn: conn, + retry: &backoff{handler: c.backoff}, + timeout: c.timeout, + retries: c.retries, + addr: c.addr, + mode: mode, + } + if c.blksize != 0 { + s.opts = make(options) + s.opts["blksize"] = strconv.Itoa(c.blksize) + } + n := packRQ(s.send, opWRQ, filename, mode, s.opts) + addr, err := s.sendWithRetry(n) + if err != nil { + return nil, err + } + s.addr = addr + s.opts = nil + return s, nil +} + +// Receive starts incoming file transmission. It returns io.WriterTo or error. +func (c Client) Receive(filename string, mode string) (io.WriterTo, error) { + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return nil, err + } + if c.timeout == 0 { + c.timeout = defaultTimeout + } + r := &receiver{ + send: make([]byte, datagramLength), + receive: make([]byte, datagramLength), + conn: conn, + retry: &backoff{handler: c.backoff}, + timeout: c.timeout, + retries: c.retries, + addr: c.addr, + autoTerm: true, + block: 1, + mode: mode, + } + if c.blksize != 0 || c.tsize { + r.opts = make(options) + } + if c.blksize != 0 { + r.opts["blksize"] = strconv.Itoa(c.blksize) + } + if c.tsize { + r.opts["tsize"] = "0" + } + n := packRQ(r.send, opRRQ, filename, mode, r.opts) + l, addr, err := r.receiveWithRetry(n) + if err != nil { + return nil, err + } + r.l = l + r.addr = addr + return r, nil +} diff --git a/vendor/github.com/pin/tftp/netascii/netascii.go b/vendor/github.com/pin/tftp/netascii/netascii.go new file mode 100644 index 00000000..92cee03b --- /dev/null +++ b/vendor/github.com/pin/tftp/netascii/netascii.go @@ -0,0 +1,108 @@ +package netascii + +// TODO: make it work not only on linux + +import "io" + +const ( + CR = '\x0d' + LF = '\x0a' + NUL = '\x00' +) + +func ToReader(r io.Reader) io.Reader { + return &toReader{ + r: r, + buf: make([]byte, 256), + } +} + +type toReader struct { + r io.Reader + buf []byte + n int + i int + err error + lf bool + nul bool +} + +func (r *toReader) Read(p []byte) (int, error) { + var n int + for n < len(p) { + if r.lf { + p[n] = LF + n++ + r.lf = false + continue + } + if r.nul { + p[n] = NUL + n++ + r.nul = false + continue + } + if r.i < r.n { + if r.buf[r.i] == LF { + p[n] = CR + r.lf = true + } else if r.buf[r.i] == CR { + p[n] = CR + r.nul = true + + } else { + p[n] = r.buf[r.i] + } + r.i++ + n++ + continue + } + if r.err == nil { + r.n, r.err = r.r.Read(r.buf) + r.i = 0 + } else { + return n, r.err + } + } + return n, r.err +} + +type fromWriter struct { + w io.Writer + buf []byte + i int + cr bool +} + +func FromWriter(w io.Writer) io.Writer { + return &fromWriter{ + w: w, + buf: make([]byte, 256), + } +} + +func (w *fromWriter) Write(p []byte) (n int, err error) { + for n < len(p) { + if w.cr { + if p[n] == LF { + w.buf[w.i] = LF + } + if p[n] == NUL { + w.buf[w.i] = CR + } + w.cr = false + w.i++ + } else if p[n] == CR { + w.cr = true + } else { + w.buf[w.i] = p[n] + w.i++ + } + n++ + if w.i == len(w.buf) || n == len(p) { + _, err = w.w.Write(w.buf[:w.i]) + w.i = 0 + } + } + return n, err +} diff --git a/vendor/github.com/pin/tftp/netascii/netascii_test.go b/vendor/github.com/pin/tftp/netascii/netascii_test.go new file mode 100644 index 00000000..7dab8cc0 --- /dev/null +++ b/vendor/github.com/pin/tftp/netascii/netascii_test.go @@ -0,0 +1,89 @@ +package netascii + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" + "testing/iotest" +) + +var basic = map[string]string{ + "\r": "\r\x00", + "\n": "\r\n", + "la\nbu": "la\r\nbu", + "la\rbu": "la\r\x00bu", + "\r\r\r": "\r\x00\r\x00\r\x00", + "\n\n\n": "\r\n\r\n\r\n", +} + +func TestTo(t *testing.T) { + for text, netascii := range basic { + to := ToReader(strings.NewReader(text)) + n, _ := ioutil.ReadAll(to) + if bytes.Compare(n, []byte(netascii)) != 0 { + t.Errorf("%q to netascii: %q != %q", text, n, netascii) + } + } +} + +func TestFrom(t *testing.T) { + for text, netascii := range basic { + r := bytes.NewReader([]byte(netascii)) + b := &bytes.Buffer{} + from := FromWriter(b) + r.WriteTo(from) + n, _ := ioutil.ReadAll(b) + if string(n) != text { + t.Errorf("%q from netascii: %q != %q", netascii, n, text) + } + } +} + +const text = ` +Therefore, the sequence "CR LF" must be treated as a single "new +line" character and used whenever their combined action is +intended; the sequence "CR NUL" must be used where a carriage +return alone is actually desired; and the CR character must be +avoided in other contexts. This rule gives assurance to systems +which must decide whether to perform a "new line" function or a +multiple-backspace that the TELNET stream contains a character +following a CR that will allow a rational decision. +(in the default ASCII mode), to preserve the symmetry of the +NVT model. Even though it may be known in some situations +(e.g., with remote echo and suppress go ahead options in +effect) that characters are not being sent to an actual +printer, nonetheless, for the sake of consistency, the protocol +requires that a NUL be inserted following a CR not followed by +a LF in the data stream. The converse of this is that a NUL +received in the data stream after a CR (in the absence of +options negotiations which explicitly specify otherwise) should +be stripped out prior to applying the NVT to local character +set mapping. +` + +func TestWriteRead(t *testing.T) { + var one bytes.Buffer + to := ToReader(strings.NewReader(text)) + one.ReadFrom(to) + two := &bytes.Buffer{} + from := FromWriter(two) + one.WriteTo(from) + text2, _ := ioutil.ReadAll(two) + if text != string(text2) { + t.Errorf("text mismatch \n%x \n%x", text, text2) + } +} + +func TestOneByte(t *testing.T) { + var one bytes.Buffer + to := iotest.OneByteReader(ToReader(strings.NewReader(text))) + one.ReadFrom(to) + two := &bytes.Buffer{} + from := FromWriter(two) + one.WriteTo(from) + text2, _ := ioutil.ReadAll(two) + if text != string(text2) { + t.Errorf("text mismatch \n%x \n%x", text, text2) + } +} diff --git a/vendor/github.com/pin/tftp/packet.go b/vendor/github.com/pin/tftp/packet.go new file mode 100644 index 00000000..1ac77428 --- /dev/null +++ b/vendor/github.com/pin/tftp/packet.go @@ -0,0 +1,190 @@ +package tftp + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +const ( + opRRQ = uint16(1) // Read request (RRQ) + opWRQ = uint16(2) // Write request (WRQ) + opDATA = uint16(3) // Data + opACK = uint16(4) // Acknowledgement + opERROR = uint16(5) // Error + opOACK = uint16(6) // Options Acknowledgment +) + +const ( + blockLength = 512 + datagramLength = 516 +) + +type options map[string]string + +// RRQ/WRQ packet +// +// 2 bytes string 1 byte string 1 byte +// -------------------------------------------------- +// | Opcode | Filename | 0 | Mode | 0 | +// -------------------------------------------------- +type pRRQ []byte +type pWRQ []byte + +// packRQ returns length of the packet in b +func packRQ(p []byte, op uint16, filename, mode string, opts options) int { + binary.BigEndian.PutUint16(p, op) + n := 2 + n += copy(p[2:len(p)-10], filename) + p[n] = 0 + n++ + n += copy(p[n:], mode) + p[n] = 0 + n++ + for name, value := range opts { + n += copy(p[n:], name) + p[n] = 0 + n++ + n += copy(p[n:], value) + p[n] = 0 + n++ + } + return n +} + +func unpackRQ(p []byte) (filename, mode string, opts options, err error) { + bs := bytes.Split(p[2:], []byte{0}) + if len(bs) < 2 { + return "", "", nil, fmt.Errorf("missing filename or mode") + } + filename = string(bs[0]) + mode = string(bs[1]) + if len(bs) < 4 { + return filename, mode, nil, nil + } + opts = make(options) + for i := 2; i+1 < len(bs); i += 2 { + opts[string(bs[i])] = string(bs[i+1]) + } + return filename, mode, opts, nil +} + +// OACK packet +// +// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ +// | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 | +// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ +type pOACK []byte + +func packOACK(p []byte, opts options) int { + binary.BigEndian.PutUint16(p, opOACK) + n := 2 + for name, value := range opts { + n += copy(p[n:], name) + p[n] = 0 + n++ + n += copy(p[n:], value) + p[n] = 0 + n++ + } + return n +} + +func unpackOACK(p []byte) (opts options, err error) { + bs := bytes.Split(p[2:], []byte{0}) + opts = make(options) + for i := 0; i+1 < len(bs); i += 2 { + opts[string(bs[i])] = string(bs[i+1]) + } + return opts, nil +} + +// ERROR packet +// +// 2 bytes 2 bytes string 1 byte +// ------------------------------------------ +// | Opcode | ErrorCode | ErrMsg | 0 | +// ------------------------------------------ +type pERROR []byte + +func packERROR(p []byte, code uint16, message string) int { + binary.BigEndian.PutUint16(p, opERROR) + binary.BigEndian.PutUint16(p[2:], code) + n := copy(p[4:len(p)-2], message) + p[4+n] = 0 + return n + 5 +} + +func (p pERROR) code() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func (p pERROR) message() string { + return string(p[4:]) +} + +// DATA packet +// +// 2 bytes 2 bytes n bytes +// ---------------------------------- +// | Opcode | Block # | Data | +// ---------------------------------- +type pDATA []byte + +func (p pDATA) block() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +// ACK packet +// +// 2 bytes 2 bytes +// ----------------------- +// | Opcode | Block # | +// ----------------------- +type pACK []byte + +func (p pACK) block() uint16 { + return binary.BigEndian.Uint16(p[2:]) +} + +func parsePacket(p []byte) (interface{}, error) { + l := len(p) + if l < 2 { + return nil, fmt.Errorf("short packet") + } + opcode := binary.BigEndian.Uint16(p) + switch opcode { + case opRRQ: + if l < 4 { + return nil, fmt.Errorf("short RRQ packet: %d", l) + } + return pRRQ(p), nil + case opWRQ: + if l < 4 { + return nil, fmt.Errorf("short WRQ packet: %d", l) + } + return pWRQ(p), nil + case opDATA: + if l < 4 { + return nil, fmt.Errorf("short DATA packet: %d", l) + } + return pDATA(p), nil + case opACK: + if l < 4 { + return nil, fmt.Errorf("short ACK packet: %d", l) + } + return pACK(p), nil + case opERROR: + if l < 5 { + return nil, fmt.Errorf("short ERROR packet: %d", l) + } + return pERROR(p), nil + case opOACK: + if l < 6 { + return nil, fmt.Errorf("short OACK packet: %d", l) + } + return pOACK(p), nil + default: + return nil, fmt.Errorf("unknown opcode: %d", opcode) + } +} diff --git a/vendor/github.com/pin/tftp/receiver.go b/vendor/github.com/pin/tftp/receiver.go new file mode 100644 index 00000000..6e0153d2 --- /dev/null +++ b/vendor/github.com/pin/tftp/receiver.go @@ -0,0 +1,234 @@ +package tftp + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/pin/tftp/netascii" +) + +// IncomingTransfer provides methods that expose information associated with +// an incoming transfer. +type IncomingTransfer interface { + // Size returns the size of an incoming file if the request included the + // tsize option (see RFC2349). To differentiate a zero-sized file transfer + // from a request without tsize use the second boolean "ok" return value. + Size() (n int64, ok bool) + + // RemoteAddr returns the remote peer's IP address and port. + RemoteAddr() net.UDPAddr +} + +func (r *receiver) RemoteAddr() net.UDPAddr { return *r.addr } + +func (r *receiver) Size() (n int64, ok bool) { + if r.opts != nil { + if s, ok := r.opts["tsize"]; ok { + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, false + } + return n, true + } + } + return 0, false +} + +type receiver struct { + send []byte + receive []byte + addr *net.UDPAddr + tid int + conn *net.UDPConn + block uint16 + retry *backoff + timeout time.Duration + retries int + l int + autoTerm bool + dally bool + mode string + opts options +} + +func (r *receiver) WriteTo(w io.Writer) (n int64, err error) { + if r.mode == "netascii" { + w = netascii.FromWriter(w) + } + if r.opts != nil { + err := r.sendOptions() + if err != nil { + r.abort(err) + return 0, err + } + } + binary.BigEndian.PutUint16(r.send[0:2], opACK) + for { + if r.l > 0 { + l, err := w.Write(r.receive[4:r.l]) + n += int64(l) + if err != nil { + r.abort(err) + return n, err + } + if r.l < len(r.receive) { + if r.autoTerm { + r.terminate() + r.conn.Close() + } + return n, nil + } + } + binary.BigEndian.PutUint16(r.send[2:4], r.block) + r.block++ // send ACK for current block and expect next one + ll, _, err := r.receiveWithRetry(4) + if err != nil { + r.abort(err) + return n, err + } + r.l = ll + } +} + +func (r *receiver) sendOptions() error { + for name, value := range r.opts { + if name == "blksize" { + err := r.setBlockSize(value) + if err != nil { + delete(r.opts, name) + continue + } + } else { + delete(r.opts, name) + } + } + if len(r.opts) > 0 { + m := packOACK(r.send, r.opts) + r.block = 1 // expect data block number 1 + ll, _, err := r.receiveWithRetry(m) + if err != nil { + r.abort(err) + return err + } + r.l = ll + } + return nil +} + +func (r *receiver) setBlockSize(blksize string) error { + n, err := strconv.Atoi(blksize) + if err != nil { + return err + } + if n < 512 { + return fmt.Errorf("blkzise too small: %d", n) + } + if n > 65464 { + return fmt.Errorf("blksize too large: %d", n) + } + r.receive = make([]byte, n+4) + return nil +} + +func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) { + r.retry.reset() + for { + n, addr, err := r.receiveDatagram(l) + if _, ok := err.(net.Error); ok && r.retry.count() < r.retries { + r.retry.backoff() + continue + } + return n, addr, err + } +} + +func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) { + err := r.conn.SetReadDeadline(time.Now().Add(r.timeout)) + if err != nil { + return 0, nil, err + } + _, err = r.conn.WriteToUDP(r.send[:l], r.addr) + if err != nil { + return 0, nil, err + } + for { + c, addr, err := r.conn.ReadFromUDP(r.receive) + if err != nil { + return 0, nil, err + } + if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) { + continue + } + p, err := parsePacket(r.receive[:c]) + if err != nil { + return 0, addr, err + } + r.tid = addr.Port + switch p := p.(type) { + case pDATA: + if p.block() == r.block { + return c, addr, nil + } + case pOACK: + opts, err := unpackOACK(p) + if r.block != 1 { + continue + } + if err != nil { + r.abort(err) + return 0, addr, err + } + for name, value := range opts { + if name == "blksize" { + err := r.setBlockSize(value) + if err != nil { + continue + } + } + } + r.block = 0 // ACK with block number 0 + r.opts = opts + return 0, addr, nil + case pERROR: + return 0, addr, fmt.Errorf("code: %d, message: %s", + p.code(), p.message()) + } + } +} + +func (r *receiver) terminate() error { + binary.BigEndian.PutUint16(r.send[2:4], r.block) + if r.dally { + for i := 0; i < 3; i++ { + _, _, err := r.receiveDatagram(4) + if err != nil { + return nil + } + } + return fmt.Errorf("dallying termination failed") + } else { + _, err := r.conn.WriteToUDP(r.send[:4], r.addr) + if err != nil { + return err + } + } + return nil +} + +func (r *receiver) abort(err error) error { + if r.conn == nil { + return nil + } + n := packERROR(r.send, 1, err.Error()) + _, err = r.conn.WriteToUDP(r.send[:n], r.addr) + if err != nil { + return err + } + r.conn.Close() + r.conn = nil + return nil +} diff --git a/vendor/github.com/pin/tftp/sender.go b/vendor/github.com/pin/tftp/sender.go new file mode 100644 index 00000000..d018c4f0 --- /dev/null +++ b/vendor/github.com/pin/tftp/sender.go @@ -0,0 +1,243 @@ +package tftp + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/pin/tftp/netascii" +) + +// OutgoingTransfer provides methods to set the outgoing transfer size and +// retrieve the remote address of the peer. +type OutgoingTransfer interface { + // SetSize is used to set the outgoing transfer size (tsize option: RFC2349) + // manually in a server write transfer handler. + // + // It is not necessary in most cases; when the io.Reader provided to + // ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will + // be determined automatically. Seek will not be attempted when the + // transfer size option is set with SetSize. + // + // The value provided will be used only if SetSize is called before ReadFrom + // and only on in a server read handler. + SetSize(n int64) + + // RemoteAddr returns the remote peer's IP address and port. + RemoteAddr() net.UDPAddr +} + +type sender struct { + conn *net.UDPConn + addr *net.UDPAddr + tid int + send []byte + receive []byte + retry *backoff + timeout time.Duration + retries int + block uint16 + mode string + opts options +} + +func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr } + +func (s *sender) SetSize(n int64) { + if s.opts != nil { + if _, ok := s.opts["tsize"]; ok { + s.opts["tsize"] = strconv.FormatInt(n, 10) + } + } +} + +func (s *sender) ReadFrom(r io.Reader) (n int64, err error) { + if s.mode == "netascii" { + r = netascii.ToReader(r) + } + if s.opts != nil { + // check that tsize is set + if ts, ok := s.opts["tsize"]; ok { + // check that tsize is not set with SetSize already + i, err := strconv.ParseInt(ts, 10, 64) + if err == nil && i == 0 { + if rs, ok := r.(io.Seeker); ok { + pos, err := rs.Seek(0, 1) + if err != nil { + return 0, err + } + size, err := rs.Seek(0, 2) + if err != nil { + return 0, err + } + s.opts["tsize"] = strconv.FormatInt(size, 10) + _, err = rs.Seek(pos, 0) + if err != nil { + return 0, err + } + } + } + } + err = s.sendOptions() + if err != nil { + s.abort(err) + return 0, err + } + } + s.block = 1 // start data transmission with block 1 + binary.BigEndian.PutUint16(s.send[0:2], opDATA) + for { + l, err := io.ReadFull(r, s.send[4:]) + n += int64(l) + if err != nil && err != io.ErrUnexpectedEOF { + if err == io.EOF { + binary.BigEndian.PutUint16(s.send[2:4], s.block) + _, err = s.sendWithRetry(4) + if err != nil { + s.abort(err) + return n, err + } + s.conn.Close() + return n, nil + } + s.abort(err) + return n, err + } + binary.BigEndian.PutUint16(s.send[2:4], s.block) + _, err = s.sendWithRetry(4 + l) + if err != nil { + s.abort(err) + return n, err + } + if l < len(s.send)-4 { + s.conn.Close() + return n, nil + } + s.block++ + } +} + +func (s *sender) sendOptions() error { + for name, value := range s.opts { + if name == "blksize" { + err := s.setBlockSize(value) + if err != nil { + delete(s.opts, name) + continue + } + } else if name == "tsize" { + if value != "0" { + s.opts["tsize"] = value + } else { + delete(s.opts, name) + continue + } + } else { + delete(s.opts, name) + } + } + if len(s.opts) > 0 { + m := packOACK(s.send, s.opts) + _, err := s.sendWithRetry(m) + if err != nil { + return err + } + } + return nil +} + +func (s *sender) setBlockSize(blksize string) error { + n, err := strconv.Atoi(blksize) + if err != nil { + return err + } + if n < 512 { + return fmt.Errorf("blkzise too small: %d", n) + } + if n > 65464 { + return fmt.Errorf("blksize too large: %d", n) + } + s.send = make([]byte, n+4) + return nil +} + +func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) { + s.retry.reset() + for { + addr, err := s.sendDatagram(l) + if _, ok := err.(net.Error); ok && s.retry.count() < s.retries { + s.retry.backoff() + continue + } + return addr, err + } +} + +func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) { + err := s.conn.SetReadDeadline(time.Now().Add(s.timeout)) + if err != nil { + return nil, err + } + _, err = s.conn.WriteToUDP(s.send[:l], s.addr) + if err != nil { + return nil, err + } + for { + n, addr, err := s.conn.ReadFromUDP(s.receive) + if err != nil { + return nil, err + } + if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) { + continue + } + p, err := parsePacket(s.receive[:n]) + if err != nil { + continue + } + s.tid = addr.Port + switch p := p.(type) { + case pACK: + if p.block() == s.block { + return addr, nil + } + case pOACK: + opts, err := unpackOACK(p) + if s.block != 0 { + continue + } + if err != nil { + s.abort(err) + return addr, err + } + for name, value := range opts { + if name == "blksize" { + err := s.setBlockSize(value) + if err != nil { + continue + } + } + } + return addr, nil + case pERROR: + return nil, fmt.Errorf("sending block %d: code=%d, error: %s", + s.block, p.code(), p.message()) + } + } +} + +func (s *sender) abort(err error) error { + if s.conn == nil { + return nil + } + n := packERROR(s.send, 1, err.Error()) + _, err = s.conn.WriteToUDP(s.send[:n], s.addr) + if err != nil { + return err + } + s.conn.Close() + s.conn = nil + return nil +} diff --git a/vendor/github.com/pin/tftp/server.go b/vendor/github.com/pin/tftp/server.go new file mode 100644 index 00000000..755119bf --- /dev/null +++ b/vendor/github.com/pin/tftp/server.go @@ -0,0 +1,199 @@ +package tftp + +import ( + "fmt" + "io" + "net" + "sync" + "time" +) + +// NewServer creates TFTP server. It requires two functions to handle +// read and write requests. +// In case nil is provided for read or write handler the respective +// operation is disabled. +func NewServer(readHandler func(filename string, rf io.ReaderFrom) error, + writeHandler func(filename string, wt io.WriterTo) error) *Server { + return &Server{ + readHandler: readHandler, + writeHandler: writeHandler, + timeout: defaultTimeout, + retries: defaultRetries, + } +} + +type Server struct { + readHandler func(filename string, rf io.ReaderFrom) error + writeHandler func(filename string, wt io.WriterTo) error + backoff backoffFunc + conn *net.UDPConn + quit chan chan struct{} + wg sync.WaitGroup + timeout time.Duration + retries int +} + +// SetTimeout sets maximum time server waits for single network +// round-trip to succeed. +// Default is 5 seconds. +func (s *Server) SetTimeout(t time.Duration) { + if t <= 0 { + s.timeout = defaultTimeout + } else { + s.timeout = t + } +} + +// SetRetries sets maximum number of attempts server made to transmit a +// packet. +// Default is 5 attempts. +func (s *Server) SetRetries(count int) { + if count < 1 { + s.retries = defaultRetries + } else { + s.retries = count + } +} + +// SetBackoff sets a user provided function that is called to provide a +// backoff duration prior to retransmitting an unacknowledged packet. +func (s *Server) SetBackoff(h backoffFunc) { + s.backoff = h +} + +// ListenAndServe binds to address provided and start the server. +// ListenAndServe returns when Shutdown is called. +func (s *Server) ListenAndServe(addr string) error { + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + conn, err := net.ListenUDP("udp", a) + if err != nil { + return err + } + s.Serve(conn) + return nil +} + +// Serve starts server provided already opened UDP connecton. It is +// useful for the case when you want to run server in separate goroutine +// but still want to be able to handle any errors opening connection. +// Serve returns when Shutdown is called or connection is closed. +func (s *Server) Serve(conn *net.UDPConn) { + s.conn = conn + s.quit = make(chan chan struct{}) + for { + select { + case q := <-s.quit: + q <- struct{}{} + return + default: + err := s.processRequest(s.conn) + if err != nil { + // TODO: add logging handler + } + } + } +} + +// Shutdown make server stop listening for new requests, allows +// server to finish outstanding transfers and stops server. +func (s *Server) Shutdown() { + s.conn.Close() + q := make(chan struct{}) + s.quit <- q + <-q + s.wg.Wait() +} + +func (s *Server) processRequest(conn *net.UDPConn) error { + var buffer []byte + buffer = make([]byte, datagramLength) + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + return fmt.Errorf("reading UDP: %v", err) + } + p, err := parsePacket(buffer[:n]) + if err != nil { + return err + } + switch p := p.(type) { + case pWRQ: + filename, mode, opts, err := unpackRQ(p) + if err != nil { + return fmt.Errorf("unpack WRQ: %v", err) + } + //fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return err + } + if err != nil { + return fmt.Errorf("open transmission: %v", err) + } + wt := &receiver{ + send: make([]byte, datagramLength), + receive: make([]byte, datagramLength), + conn: conn, + retry: &backoff{handler: s.backoff}, + timeout: s.timeout, + retries: s.retries, + addr: remoteAddr, + mode: mode, + opts: opts, + } + s.wg.Add(1) + go func() { + if s.writeHandler != nil { + err := s.writeHandler(filename, wt) + if err != nil { + wt.abort(err) + } else { + wt.terminate() + wt.conn.Close() + } + } else { + wt.abort(fmt.Errorf("server does not support write requests")) + } + s.wg.Done() + }() + case pRRQ: + filename, mode, opts, err := unpackRQ(p) + if err != nil { + return fmt.Errorf("unpack RRQ: %v", err) + } + //fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return err + } + rf := &sender{ + send: make([]byte, datagramLength), + receive: make([]byte, datagramLength), + tid: remoteAddr.Port, + conn: conn, + retry: &backoff{handler: s.backoff}, + timeout: s.timeout, + retries: s.retries, + addr: remoteAddr, + mode: mode, + opts: opts, + } + s.wg.Add(1) + go func() { + if s.readHandler != nil { + err := s.readHandler(filename, rf) + if err != nil { + rf.abort(err) + } + } else { + rf.abort(fmt.Errorf("server does not support read requests")) + } + s.wg.Done() + }() + default: + return fmt.Errorf("unexpected %T", p) + } + return nil +} diff --git a/vendor/github.com/pin/tftp/tftp_test.go b/vendor/github.com/pin/tftp/tftp_test.go new file mode 100644 index 00000000..f77ae3da --- /dev/null +++ b/vendor/github.com/pin/tftp/tftp_test.go @@ -0,0 +1,622 @@ +package tftp + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net" + "os" + "strconv" + "sync" + "testing" + "testing/iotest" + "time" +) + +var localhost string = determineLocalhost() + +func determineLocalhost() string { + l, err := net.ListenTCP("tcp", nil) + if err != nil { + panic(fmt.Sprintf("ListenTCP error: %s", err)) + } + _, lport, _ := net.SplitHostPort(l.Addr().String()) + defer l.Close() + + lo := make(chan string) + + go func() { + for { + conn, err := l.Accept() + if err != nil { + break + } + conn.Close() + } + }() + + go func() { + port, _ := strconv.Atoi(lport) + for _, af := range []string{"tcp6", "tcp4"} { + conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) + if err == nil { + conn.Close() + host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) + lo <- host + return + } + } + panic("could not determine address family") + }() + + return <-lo +} + +func localSystem(c *net.UDPConn) string { + _, port, _ := net.SplitHostPort(c.LocalAddr().String()) + return net.JoinHostPort(localhost, port) +} + +func TestPackUnpack(t *testing.T) { + v := []string{"test-filename/with-subdir"} + testOptsList := []options{ + nil, + options{ + "tsize": "1234", + "blksize": "22", + }, + } + for _, filename := range v { + for _, mode := range []string{"octet", "netascii"} { + for _, opts := range testOptsList { + packUnpack(t, filename, mode, opts) + } + } + } +} + +func packUnpack(t *testing.T, filename, mode string, opts options) { + b := make([]byte, datagramLength) + for _, op := range []uint16{opRRQ, opWRQ} { + n := packRQ(b, op, filename, mode, opts) + f, m, o, err := unpackRQ(b[:n]) + if err != nil { + t.Errorf("%s pack/unpack: %v", filename, err) + } + if f != filename { + t.Errorf("filename mismatch (%s): '%x' vs '%x'", + filename, f, filename) + } + if m != mode { + t.Errorf("mode mismatch (%s): '%x' vs '%x'", + mode, m, mode) + } + if opts != nil { + for name, value := range opts { + v, ok := o[name] + if !ok { + t.Errorf("missing %s option", name) + } + if v != value { + t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) + } + } + } + } +} + +func TestZeroLength(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + testSendReceive(t, c, 0) +} + +func Test900(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + for i := 600; i < 4000; i += 1 { + c.blksize = i + testSendReceive(t, c, 9000+int64(i)) + } +} + +func Test1000(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + for i := int64(0); i < 5000; i++ { + filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) + rf, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("requesting %s write: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(i)), i) + n, err := rf.ReadFrom(r) + if err != nil { + t.Fatalf("sending %s: %v", filename, err) + } + if n != i { + t.Errorf("%s length mismatch: %d != %d", filename, n, i) + } + } +} + +func Test1810(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + c.blksize = 1810 + testSendReceive(t, c, 9000+1810) +} + +func TestTSize(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + c.tsize = true + testSendReceive(t, c, 640) +} + +func TestNearBlockLength(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + for i := 450; i < 520; i++ { + testSendReceive(t, c, int64(i)) + } +} + +func TestBlockWrapsAround(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + n := 65535 * 512 + for i := n - 2; i < n+2; i++ { + testSendReceive(t, c, int64(i)) + } +} + +func TestRandomLength(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + r := rand.New(rand.NewSource(42)) + for i := 0; i < 100; i++ { + testSendReceive(t, c, r.Int63n(100000)) + } +} + +func TestBigFile(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + testSendReceive(t, c, 3*1000*1000) +} + +func TestByOneByte(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + filename := "test-by-one-byte" + mode := "octet" + const length = 80000 + sender, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + r := iotest.OneByteReader(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + n, err := sender.ReadFrom(r) + if err != nil { + t.Fatalf("send error: %v", err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := ioutil.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } +} + +func TestDuplicate(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + filename := "test-duplicate" + mode := "octet" + bs := []byte("lalala") + sender, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write: %v", err) + } + buf := bytes.NewBuffer(bs) + _, err = sender.ReadFrom(buf) + if err != nil { + t.Fatalf("send error: %v", err) + } + sender, err = c.Send(filename, mode) + if err == nil { + t.Fatalf("file already exists") + } + t.Logf("sending file that already exists: %v", err) +} + +func TestNotFound(t *testing.T) { + s, c := makeTestServer() + defer s.Shutdown() + filename := "test-not-exists" + mode := "octet" + _, err := c.Receive(filename, mode) + if err == nil { + t.Fatalf("file not exists", err) + } + t.Logf("receiving file that does not exist: %v", err) +} + +func testSendReceive(t *testing.T, client *Client, length int64) { + filename := fmt.Sprintf("length-%d-bytes", length) + mode := "octet" + writeTransfer, err := client.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), length) + n, err := writeTransfer.ReadFrom(r) + if err != nil { + t.Fatalf("%s write error: %v", filename, err) + } + if n != length { + t.Errorf("%s write length mismatch: %d != %d", filename, n, length) + } + readTransfer, err := client.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + if it, ok := readTransfer.(IncomingTransfer); ok { + if n, ok := it.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + if n != length { + t.Errorf("tsize mismatch: %d vs %d", n, length) + } + } + } + buf := &bytes.Buffer{} + n, err = readTransfer.WriteTo(buf) + if err != nil { + t.Fatalf("%s read error: %v", filename, err) + } + if n != length { + t.Errorf("%s read length mismatch: %d != %d", filename, n, length) + } + bs, _ := ioutil.ReadAll(io.LimitReader( + newRandReader(rand.NewSource(42)), length)) + if !bytes.Equal(bs, buf.Bytes()) { + t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) + } +} + +func TestSendTsizeFromSeek(t *testing.T) { + // create read-only server + s := NewServer(func(filename string, rf io.ReaderFrom) error { + b := make([]byte, 100) + rr := newRandReader(rand.NewSource(42)) + rr.Read(b) + // bytes.Reader implements io.Seek + r := bytes.NewReader(b) + _, err := rf.ReadFrom(r) + if err != nil { + t.Errorf("sending bytes: %v", err) + } + return nil + }, nil) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + t.Fatalf("listening: %v", err) + } + + go s.Serve(conn) + defer s.Shutdown() + + c, _ := NewClient(localSystem(conn)) + c.tsize = true + r, _ := c.Receive("f", "octet") + var size int64 + if t, ok := r.(IncomingTransfer); ok { + if n, ok := t.Size(); ok { + size = n + fmt.Printf("Transfer size: %d\n", n) + } + } + + if size != 100 { + t.Errorf("size expected: 100, got %d", size) + } + + r.WriteTo(ioutil.Discard) +} + +type testBackend struct { + m map[string][]byte + mu sync.Mutex +} + +func makeTestServer() (*Server, *Client) { + b := &testBackend{} + b.m = make(map[string][]byte) + + // Create server + s := NewServer(b.handleRead, b.handleWrite) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + panic(err) + } + + go s.Serve(conn) + + // Create client for that server + c, err := NewClient(localSystem(conn)) + if err != nil { + panic(err) + } + + return s, c +} + +func TestNoHandlers(t *testing.T) { + s := NewServer(nil, nil) + + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + panic(err) + } + + go s.Serve(conn) + + c, err := NewClient(localSystem(conn)) + if err != nil { + panic(err) + } + + _, err = c.Send("test", "octet") + if err == nil { + t.Errorf("error expected") + } + + _, err = c.Receive("test", "octet") + if err == nil { + t.Errorf("error expected") + } +} + +func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error { + b.mu.Lock() + defer b.mu.Unlock() + _, ok := b.m[filename] + if ok { + fmt.Fprintf(os.Stderr, "File %s already exists\n", filename) + return fmt.Errorf("file already exists") + } + if t, ok := wt.(IncomingTransfer); ok { + if n, ok := t.Size(); ok { + fmt.Printf("Transfer size: %d\n", n) + } + } + buf := &bytes.Buffer{} + _, err := wt.WriteTo(buf) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err) + return err + } + b.m[filename] = buf.Bytes() + return nil +} + +func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error { + b.mu.Lock() + defer b.mu.Unlock() + bs, ok := b.m[filename] + if !ok { + fmt.Fprintf(os.Stderr, "File %s not found\n", filename) + return fmt.Errorf("file not found") + } + if t, ok := rf.(OutgoingTransfer); ok { + t.SetSize(int64(len(bs))) + } + _, err := rf.ReadFrom(bytes.NewBuffer(bs)) + if err != nil { + fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err) + return err + } + return nil +} + +type randReader struct { + src rand.Source + next int64 + i int8 +} + +func newRandReader(src rand.Source) io.Reader { + r := &randReader{ + src: src, + next: src.Int63(), + } + return r +} + +func (r *randReader) Read(p []byte) (n int, err error) { + next, i := r.next, r.i + for n = 0; n < len(p); n++ { + if i == 7 { + next, i = r.src.Int63(), 0 + } + p[n] = byte(next) + next >>= 8 + i++ + } + r.next, r.i = next, i + return +} + +func TestServerSendTimeout(t *testing.T) { + s, c := makeTestServer() + s.SetTimeout(time.Second) + s.SetRetries(2) + var serverErr error + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, serverErr = rf.ReadFrom(r) + return serverErr + } + defer s.Shutdown() + filename := "test-server-send-timeout" + mode := "octet" + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, _ = readTransfer.WriteTo(w) + netErr, ok := serverErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", serverErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", serverErr) + } +} + +func TestServerReceiveTimeout(t *testing.T) { + s, c := makeTestServer() + s.SetTimeout(time.Second) + s.SetRetries(2) + var serverErr error + s.writeHandler = func(filename string, wt io.WriterTo) error { + buf := &bytes.Buffer{} + _, serverErr = wt.WriteTo(buf) + return serverErr + } + defer s.Shutdown() + filename := "test-server-receive-timeout" + mode := "octet" + writeTransfer, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, _ = writeTransfer.ReadFrom(r) + netErr, ok := serverErr.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", serverErr) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", serverErr) + } +} + +func TestClientReceiveTimeout(t *testing.T) { + s, c := makeTestServer() + c.SetTimeout(time.Second) + c.SetRetries(2) + s.readHandler = func(filename string, rf io.ReaderFrom) error { + r := &slowReader{ + r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000), + n: 3, + delay: 8 * time.Second, + } + _, err := rf.ReadFrom(r) + return err + } + defer s.Shutdown() + filename := "test-client-receive-timeout" + mode := "octet" + readTransfer, err := c.Receive(filename, mode) + if err != nil { + t.Fatalf("requesting read %s: %v", filename, err) + } + buf := &bytes.Buffer{} + _, err = readTransfer.WriteTo(buf) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } +} + +func TestClientSendTimeout(t *testing.T) { + s, c := makeTestServer() + c.SetTimeout(time.Second) + c.SetRetries(2) + s.writeHandler = func(filename string, wt io.WriterTo) error { + w := &slowWriter{ + n: 3, + delay: 8 * time.Second, + } + _, err := wt.WriteTo(w) + return err + } + defer s.Shutdown() + filename := "test-client-send-timeout" + mode := "octet" + writeTransfer, err := c.Send(filename, mode) + if err != nil { + t.Fatalf("requesting write %s: %v", filename, err) + } + r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000) + _, err = writeTransfer.ReadFrom(r) + netErr, ok := err.(net.Error) + if !ok { + t.Fatalf("network error expected: %T", err) + } + if !netErr.Timeout() { + t.Fatalf("timout is expected: %v", err) + } +} + +type slowReader struct { + r io.Reader + n int64 + delay time.Duration +} + +func (r *slowReader) Read(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return r.r.Read(p) + } + time.Sleep(r.delay) + return r.r.Read(p) +} + +type slowWriter struct { + r io.Reader + n int64 + delay time.Duration +} + +func (r *slowWriter) Write(p []byte) (n int, err error) { + if r.n > 0 { + r.n-- + return len(p), nil + } + time.Sleep(r.delay) + return len(p), nil +}