1
0
mirror of https://github.com/rancher/os.git synced 2025-07-15 23:55:51 +00:00

Added tftp datasource for cloud config.

This commit is contained in:
Morten Møller Riis 2019-03-12 11:37:08 +01:00 committed by niusmallnan
parent 8b75752225
commit 66c5f6130a
18 changed files with 2252 additions and 0 deletions

View File

@ -38,6 +38,7 @@ import (
"github.com/rancher/os/config/cloudinit/datasource/metadata/gce"
"github.com/rancher/os/config/cloudinit/datasource/metadata/packet"
"github.com/rancher/os/config/cloudinit/datasource/proccmdline"
"github.com/rancher/os/config/cloudinit/datasource/tftp"
"github.com/rancher/os/config/cloudinit/datasource/url"
"github.com/rancher/os/config/cloudinit/datasource/vmware"
"github.com/rancher/os/config/cloudinit/pkg"
@ -237,6 +238,8 @@ func getDatasources(datasources []string) []datasource.Datasource {
if root != "" {
dss = append(dss, file.NewDatasource(root))
}
case "tftp":
dss = append(dss, tftp.NewDatasource(root))
case "url":
if root != "" {
dss = append(dss, url.NewDatasource(root))

View File

@ -0,0 +1,83 @@
package tftp
import (
"bytes"
"fmt"
"io"
"regexp"
"strings"
"github.com/rancher/os/config/cloudinit/datasource"
"github.com/pin/tftp"
)
type Client interface {
Receive(filename string, mode string) (io.WriterTo, error)
}
type RemoteFile struct {
host string
path string
client Client
stream io.WriterTo
lastError error
}
func NewDatasource(hostAndPath string) *RemoteFile {
parts := strings.SplitN(hostAndPath, "/", 2)
if len(parts) < 2 {
return &RemoteFile{hostAndPath, "", nil, nil, nil}
}
host := parts[0]
if match, _ := regexp.MatchString(":[0-9]{2,5}$", host); !match {
// No port, using default port 69
host += ":69"
}
path := parts[1]
if client, lastError := tftp.NewClient(host); lastError == nil {
return &RemoteFile{host, path, client, nil, nil}
}
return &RemoteFile{host, path, nil, nil, nil}
}
func (f *RemoteFile) IsAvailable() bool {
f.stream, f.lastError = f.client.Receive(f.path, "octet")
return f.lastError == nil
}
func (f *RemoteFile) Finish() error {
return nil
}
func (f *RemoteFile) String() string {
return fmt.Sprintf("%s: %s%s (lastError: %v)", f.Type(), f.host, f.path, f.lastError)
}
func (f *RemoteFile) AvailabilityChanges() bool {
return false
}
func (f *RemoteFile) ConfigRoot() string {
return ""
}
func (f *RemoteFile) FetchMetadata() (datasource.Metadata, error) {
return datasource.Metadata{}, nil
}
func (f *RemoteFile) FetchUserdata() ([]byte, error) {
var b bytes.Buffer
_, err := f.stream.WriteTo(&b)
return b.Bytes(), err
}
func (f *RemoteFile) Type() string {
return "tftp"
}

View File

@ -0,0 +1,92 @@
package tftp
import (
"fmt"
"io"
"reflect"
"testing"
)
type mockClient struct {
}
type mockReceiver struct {
}
func (r mockReceiver) WriteTo(w io.Writer) (n int64, err error) {
b := []byte("cloud-config file")
w.Write(b)
return int64(len(b)), nil
}
func (c mockClient) Receive(filename string, mode string) (io.WriterTo, error) {
if filename == "does-not-exist" {
return &mockReceiver{}, fmt.Errorf("does not exist")
}
return &mockReceiver{}, nil
}
var _ Client = (*mockClient)(nil)
func TestNewDatasource(t *testing.T) {
for _, tt := range []struct {
root string
expectHost string
expectPath string
}{
{
root: "127.0.0.1/test/file.yaml",
expectHost: "127.0.0.1:69",
expectPath: "test/file.yaml",
},
{
root: "127.0.0.1/test/file.yaml",
expectHost: "127.0.0.1:69",
expectPath: "test/file.yaml",
},
} {
ds := NewDatasource(tt.root)
if ds.host != tt.expectHost || ds.path != tt.expectPath {
t.Fatalf("bad host or path (%q): want host=%s, got %s, path=%s, got %s", tt.root, tt.expectHost, ds.host, tt.expectPath, ds.path)
}
}
}
func TestIsAvailable(t *testing.T) {
for _, tt := range []struct {
remoteFile *RemoteFile
expect bool
}{
{
remoteFile: &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil},
expect: true,
},
{
remoteFile: &RemoteFile{"1.2.3.4", "does-not-exist", &mockClient{}, nil, nil},
expect: false,
},
} {
if tt.remoteFile.IsAvailable() != tt.expect {
t.Fatalf("expected remote file %s to be %v", tt.remoteFile.path, tt.expect)
}
}
}
func TestFetchUserdata(t *testing.T) {
rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, &mockReceiver{}, nil}
b, _ := rf.FetchUserdata()
expect := []byte("cloud-config file")
if len(b) != len(expect) || !reflect.DeepEqual(b, expect) {
t.Fatalf("expected length of buffer to be %d was %d. Expected %s, got %s", len(expect), len(b), string(expect), string(b))
}
}
func TestType(t *testing.T) {
rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil}
if rf.Type() != "tftp" {
t.Fatalf("expected remote file Type() to return %s got %s", "tftp", rf.Type())
}
}

View File

@ -57,3 +57,4 @@ golang.org/x/sys eb2c74142fd19a79b3f237334c7384d5167b1b46 https://github.com/gol
google.golang.org/grpc ab0be5212fb225475f2087566eded7da5d727960 https://github.com/grpc/grpc-go.git
gopkg.in/fsnotify.v1 v1.2.0
github.com/fatih/structs dc3312cb1a4513a366c4c9e622ad55c32df12ed3
github.com/pin/tftp v2.1.0

24
vendor/github.com/pin/tftp/.gitignore generated vendored Normal file
View File

@ -0,0 +1,24 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof

8
vendor/github.com/pin/tftp/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,8 @@
language: go
os:
- linux
- osx
before_install:
- ulimit -n 4096

4
vendor/github.com/pin/tftp/CONTRIBUTORS generated vendored Normal file
View File

@ -0,0 +1,4 @@
Dmitri Popov
Mojo Talantikite
Giovanni Bajo
Andrew Danforth

21
vendor/github.com/pin/tftp/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Dmitri Popov
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

171
vendor/github.com/pin/tftp/README.md generated vendored Normal file
View File

@ -0,0 +1,171 @@
TFTP server and client library for Golang
=========================================
[![GoDoc](https://godoc.org/github.com/pin/tftp?status.svg)](https://godoc.org/github.com/pin/tftp)
[![Build Status](https://travis-ci.org/pin/tftp.svg?branch=master)](https://travis-ci.org/pin/tftp)
Implements:
* [RFC 1350](https://tools.ietf.org/html/rfc1350) - The TFTP Protocol (Revision 2)
* [RFC 2347](https://tools.ietf.org/html/rfc2347) - TFTP Option Extension
* [RFC 2348](https://tools.ietf.org/html/rfc2348) - TFTP Blocksize Option
Partially implements (tsize server side only):
* [RFC 2349](https://tools.ietf.org/html/rfc2349) - TFTP Timeout Interval and Transfer Size Options
Set of features is sufficient for PXE boot support.
``` go
import "github.com/pin/tftp"
```
The package is cohesive to Golang `io`. Particularly it implements
`io.ReaderFrom` and `io.WriterTo` interfaces. That allows efficient data
transmission without unnecessary memory copying and allocations.
TFTP Server
-----------
```go
// readHandler is called when client starts file download from server
func readHandler(filename string, rf io.ReaderFrom) error {
file, err := os.Open(filename)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
n, err := rf.ReadFrom(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
fmt.Printf("%d bytes sent\n", n)
return nil
}
// writeHandler is called when client starts file upload to server
func writeHandler(filename string, wt io.WriterTo) error {
file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
n, err := wt.WriteTo(file)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
return err
}
fmt.Printf("%d bytes received\n", n)
return nil
}
func main() {
// use nil in place of handler to disable read or write operations
s := tftp.NewServer(readHandler, writeHandler)
s.SetTimeout(5 * time.Second) // optional
err := s.ListenAndServe(":69") // blocks until s.Shutdown() is called
if err != nil {
fmt.Fprintf(os.Stdout, "server: %v\n", err)
os.Exit(1)
}
}
```
TFTP Client
-----------
Upload file to server:
```go
c, err := tftp.NewClient("172.16.4.21:69")
file, err := os.Open(path)
c.SetTimeout(5 * time.Second) // optional
rf, err := c.Send("foobar.txt", "octet")
n, err := rf.ReadFrom(file)
fmt.Printf("%d bytes sent\n", n)
```
Download file from server:
```go
c, err := tftp.NewClient("172.16.4.21:69")
wt, err := c.Receive("foobar.txt", "octet")
file, err := os.Create(path)
// Optionally obtain transfer size before actual data.
if n, ok := wt.(IncomingTransfer).Size(); ok {
fmt.Printf("Transfer size: %d\n", n)
}
n, err := wt.WriteTo(file)
fmt.Printf("%d bytes received\n", n)
```
Note: please handle errors better :)
TSize option
------------
PXE boot ROM often expects tsize option support from a server: client
(e.g. computer that boots over the network) wants to know size of a
download before the actual data comes. Server has to obtain stream
size and send it to a client.
Often it will happen automatically because TFTP library tries to check
if `io.Reader` provided to `ReadFrom` method also satisfies
`io.Seeker` interface (`os.File` for instance) and uses `Seek` to
determine file size.
In case `io.Reader` you provide to `ReadFrom` in read handler does not
satisfy `io.Seeker` interface or you do not want TFTP library to call
`Seek` on your reader but still want to respond with tsize option
during outgoing request you can use an `OutgoingTransfer` interface:
```go
func readHandler(filename string, rf io.ReaderFrom) error {
...
// Set transfer size before calling ReadFrom.
rf.(tftp.OutgoingTransfer).SetSize(myFileSize)
...
// ReadFrom ...
```
Similarly, it is possible to obtain size of a file that is about to be
received using `IncomingTransfer` interface (see `Size` method).
Remote Address
--------------
The `OutgoingTransfer` and `IncomingTransfer` interfaces also provide the
`RemoteAddr` method which returns the peer IP address and port as a
`net.UDPAddr`. This can be used for detailed logging in a server handler.
```go
func readHandler(filename string, rf io.ReaderFrom) error {
...
raddr := rf.(tftp.OutgoingTransfer).RemoteAddr()
log.Println("RRQ from", raddr.String())
...
// ReadFrom ...
```
Backoff
-------
The default backoff before retransmitting an unacknowledged packet is a
random duration between 0 and 1 second. This behavior can be overridden
in clients and servers by providing a custom backoff calculation function.
```go
s := tftp.NewServer(readHandler, writeHandler)
s.SetBackoff(func (attempts int) time.Duration {
return time.Duration(attempts) * time.Second
})
```
or, for no backoff
```go
s.SetBackoff(func (int) time.Duration { return 0 })
```

35
vendor/github.com/pin/tftp/backoff.go generated vendored Normal file
View File

@ -0,0 +1,35 @@
package tftp
import (
"math/rand"
"time"
)
const (
defaultTimeout = 5 * time.Second
defaultRetries = 5
)
type backoffFunc func(int) time.Duration
type backoff struct {
attempt int
handler backoffFunc
}
func (b *backoff) reset() {
b.attempt = 0
}
func (b *backoff) count() int {
return b.attempt
}
func (b *backoff) backoff() {
if b.handler == nil {
time.Sleep(time.Duration(rand.Int63n(int64(time.Second))))
} else {
time.Sleep(b.handler(b.attempt))
}
b.attempt++
}

125
vendor/github.com/pin/tftp/client.go generated vendored Normal file
View File

@ -0,0 +1,125 @@
package tftp
import (
"fmt"
"io"
"net"
"strconv"
"time"
)
// NewClient creates TFTP client for server on address provided.
func NewClient(addr string) (*Client, error) {
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, fmt.Errorf("resolving address %s: %v", addr, err)
}
return &Client{
addr: a,
timeout: defaultTimeout,
retries: defaultRetries,
}, nil
}
// SetTimeout sets maximum time client waits for single network round-trip to succeed.
// Default is 5 seconds.
func (c *Client) SetTimeout(t time.Duration) {
if t <= 0 {
c.timeout = defaultTimeout
}
c.timeout = t
}
// SetRetries sets maximum number of attempts client made to transmit a packet.
// Default is 5 attempts.
func (c *Client) SetRetries(count int) {
if count < 1 {
c.retries = defaultRetries
}
c.retries = count
}
// SetBackoff sets a user provided function that is called to provide a
// backoff duration prior to retransmitting an unacknowledged packet.
func (c *Client) SetBackoff(h backoffFunc) {
c.backoff = h
}
type Client struct {
addr *net.UDPAddr
timeout time.Duration
retries int
backoff backoffFunc
blksize int
tsize bool
}
// Send starts outgoing file transmission. It returns io.ReaderFrom or error.
func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return nil, err
}
s := &sender{
send: make([]byte, datagramLength),
receive: make([]byte, datagramLength),
conn: conn,
retry: &backoff{handler: c.backoff},
timeout: c.timeout,
retries: c.retries,
addr: c.addr,
mode: mode,
}
if c.blksize != 0 {
s.opts = make(options)
s.opts["blksize"] = strconv.Itoa(c.blksize)
}
n := packRQ(s.send, opWRQ, filename, mode, s.opts)
addr, err := s.sendWithRetry(n)
if err != nil {
return nil, err
}
s.addr = addr
s.opts = nil
return s, nil
}
// Receive starts incoming file transmission. It returns io.WriterTo or error.
func (c Client) Receive(filename string, mode string) (io.WriterTo, error) {
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return nil, err
}
if c.timeout == 0 {
c.timeout = defaultTimeout
}
r := &receiver{
send: make([]byte, datagramLength),
receive: make([]byte, datagramLength),
conn: conn,
retry: &backoff{handler: c.backoff},
timeout: c.timeout,
retries: c.retries,
addr: c.addr,
autoTerm: true,
block: 1,
mode: mode,
}
if c.blksize != 0 || c.tsize {
r.opts = make(options)
}
if c.blksize != 0 {
r.opts["blksize"] = strconv.Itoa(c.blksize)
}
if c.tsize {
r.opts["tsize"] = "0"
}
n := packRQ(r.send, opRRQ, filename, mode, r.opts)
l, addr, err := r.receiveWithRetry(n)
if err != nil {
return nil, err
}
r.l = l
r.addr = addr
return r, nil
}

108
vendor/github.com/pin/tftp/netascii/netascii.go generated vendored Normal file
View File

@ -0,0 +1,108 @@
package netascii
// TODO: make it work not only on linux
import "io"
const (
CR = '\x0d'
LF = '\x0a'
NUL = '\x00'
)
func ToReader(r io.Reader) io.Reader {
return &toReader{
r: r,
buf: make([]byte, 256),
}
}
type toReader struct {
r io.Reader
buf []byte
n int
i int
err error
lf bool
nul bool
}
func (r *toReader) Read(p []byte) (int, error) {
var n int
for n < len(p) {
if r.lf {
p[n] = LF
n++
r.lf = false
continue
}
if r.nul {
p[n] = NUL
n++
r.nul = false
continue
}
if r.i < r.n {
if r.buf[r.i] == LF {
p[n] = CR
r.lf = true
} else if r.buf[r.i] == CR {
p[n] = CR
r.nul = true
} else {
p[n] = r.buf[r.i]
}
r.i++
n++
continue
}
if r.err == nil {
r.n, r.err = r.r.Read(r.buf)
r.i = 0
} else {
return n, r.err
}
}
return n, r.err
}
type fromWriter struct {
w io.Writer
buf []byte
i int
cr bool
}
func FromWriter(w io.Writer) io.Writer {
return &fromWriter{
w: w,
buf: make([]byte, 256),
}
}
func (w *fromWriter) Write(p []byte) (n int, err error) {
for n < len(p) {
if w.cr {
if p[n] == LF {
w.buf[w.i] = LF
}
if p[n] == NUL {
w.buf[w.i] = CR
}
w.cr = false
w.i++
} else if p[n] == CR {
w.cr = true
} else {
w.buf[w.i] = p[n]
w.i++
}
n++
if w.i == len(w.buf) || n == len(p) {
_, err = w.w.Write(w.buf[:w.i])
w.i = 0
}
}
return n, err
}

89
vendor/github.com/pin/tftp/netascii/netascii_test.go generated vendored Normal file
View File

@ -0,0 +1,89 @@
package netascii
import (
"bytes"
"io/ioutil"
"strings"
"testing"
"testing/iotest"
)
var basic = map[string]string{
"\r": "\r\x00",
"\n": "\r\n",
"la\nbu": "la\r\nbu",
"la\rbu": "la\r\x00bu",
"\r\r\r": "\r\x00\r\x00\r\x00",
"\n\n\n": "\r\n\r\n\r\n",
}
func TestTo(t *testing.T) {
for text, netascii := range basic {
to := ToReader(strings.NewReader(text))
n, _ := ioutil.ReadAll(to)
if bytes.Compare(n, []byte(netascii)) != 0 {
t.Errorf("%q to netascii: %q != %q", text, n, netascii)
}
}
}
func TestFrom(t *testing.T) {
for text, netascii := range basic {
r := bytes.NewReader([]byte(netascii))
b := &bytes.Buffer{}
from := FromWriter(b)
r.WriteTo(from)
n, _ := ioutil.ReadAll(b)
if string(n) != text {
t.Errorf("%q from netascii: %q != %q", netascii, n, text)
}
}
}
const text = `
Therefore, the sequence "CR LF" must be treated as a single "new
line" character and used whenever their combined action is
intended; the sequence "CR NUL" must be used where a carriage
return alone is actually desired; and the CR character must be
avoided in other contexts. This rule gives assurance to systems
which must decide whether to perform a "new line" function or a
multiple-backspace that the TELNET stream contains a character
following a CR that will allow a rational decision.
(in the default ASCII mode), to preserve the symmetry of the
NVT model. Even though it may be known in some situations
(e.g., with remote echo and suppress go ahead options in
effect) that characters are not being sent to an actual
printer, nonetheless, for the sake of consistency, the protocol
requires that a NUL be inserted following a CR not followed by
a LF in the data stream. The converse of this is that a NUL
received in the data stream after a CR (in the absence of
options negotiations which explicitly specify otherwise) should
be stripped out prior to applying the NVT to local character
set mapping.
`
func TestWriteRead(t *testing.T) {
var one bytes.Buffer
to := ToReader(strings.NewReader(text))
one.ReadFrom(to)
two := &bytes.Buffer{}
from := FromWriter(two)
one.WriteTo(from)
text2, _ := ioutil.ReadAll(two)
if text != string(text2) {
t.Errorf("text mismatch \n%x \n%x", text, text2)
}
}
func TestOneByte(t *testing.T) {
var one bytes.Buffer
to := iotest.OneByteReader(ToReader(strings.NewReader(text)))
one.ReadFrom(to)
two := &bytes.Buffer{}
from := FromWriter(two)
one.WriteTo(from)
text2, _ := ioutil.ReadAll(two)
if text != string(text2) {
t.Errorf("text mismatch \n%x \n%x", text, text2)
}
}

190
vendor/github.com/pin/tftp/packet.go generated vendored Normal file
View File

@ -0,0 +1,190 @@
package tftp
import (
"bytes"
"encoding/binary"
"fmt"
)
const (
opRRQ = uint16(1) // Read request (RRQ)
opWRQ = uint16(2) // Write request (WRQ)
opDATA = uint16(3) // Data
opACK = uint16(4) // Acknowledgement
opERROR = uint16(5) // Error
opOACK = uint16(6) // Options Acknowledgment
)
const (
blockLength = 512
datagramLength = 516
)
type options map[string]string
// RRQ/WRQ packet
//
// 2 bytes string 1 byte string 1 byte
// --------------------------------------------------
// | Opcode | Filename | 0 | Mode | 0 |
// --------------------------------------------------
type pRRQ []byte
type pWRQ []byte
// packRQ returns length of the packet in b
func packRQ(p []byte, op uint16, filename, mode string, opts options) int {
binary.BigEndian.PutUint16(p, op)
n := 2
n += copy(p[2:len(p)-10], filename)
p[n] = 0
n++
n += copy(p[n:], mode)
p[n] = 0
n++
for name, value := range opts {
n += copy(p[n:], name)
p[n] = 0
n++
n += copy(p[n:], value)
p[n] = 0
n++
}
return n
}
func unpackRQ(p []byte) (filename, mode string, opts options, err error) {
bs := bytes.Split(p[2:], []byte{0})
if len(bs) < 2 {
return "", "", nil, fmt.Errorf("missing filename or mode")
}
filename = string(bs[0])
mode = string(bs[1])
if len(bs) < 4 {
return filename, mode, nil, nil
}
opts = make(options)
for i := 2; i+1 < len(bs); i += 2 {
opts[string(bs[i])] = string(bs[i+1])
}
return filename, mode, opts, nil
}
// OACK packet
//
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
// | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
type pOACK []byte
func packOACK(p []byte, opts options) int {
binary.BigEndian.PutUint16(p, opOACK)
n := 2
for name, value := range opts {
n += copy(p[n:], name)
p[n] = 0
n++
n += copy(p[n:], value)
p[n] = 0
n++
}
return n
}
func unpackOACK(p []byte) (opts options, err error) {
bs := bytes.Split(p[2:], []byte{0})
opts = make(options)
for i := 0; i+1 < len(bs); i += 2 {
opts[string(bs[i])] = string(bs[i+1])
}
return opts, nil
}
// ERROR packet
//
// 2 bytes 2 bytes string 1 byte
// ------------------------------------------
// | Opcode | ErrorCode | ErrMsg | 0 |
// ------------------------------------------
type pERROR []byte
func packERROR(p []byte, code uint16, message string) int {
binary.BigEndian.PutUint16(p, opERROR)
binary.BigEndian.PutUint16(p[2:], code)
n := copy(p[4:len(p)-2], message)
p[4+n] = 0
return n + 5
}
func (p pERROR) code() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p pERROR) message() string {
return string(p[4:])
}
// DATA packet
//
// 2 bytes 2 bytes n bytes
// ----------------------------------
// | Opcode | Block # | Data |
// ----------------------------------
type pDATA []byte
func (p pDATA) block() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
// ACK packet
//
// 2 bytes 2 bytes
// -----------------------
// | Opcode | Block # |
// -----------------------
type pACK []byte
func (p pACK) block() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func parsePacket(p []byte) (interface{}, error) {
l := len(p)
if l < 2 {
return nil, fmt.Errorf("short packet")
}
opcode := binary.BigEndian.Uint16(p)
switch opcode {
case opRRQ:
if l < 4 {
return nil, fmt.Errorf("short RRQ packet: %d", l)
}
return pRRQ(p), nil
case opWRQ:
if l < 4 {
return nil, fmt.Errorf("short WRQ packet: %d", l)
}
return pWRQ(p), nil
case opDATA:
if l < 4 {
return nil, fmt.Errorf("short DATA packet: %d", l)
}
return pDATA(p), nil
case opACK:
if l < 4 {
return nil, fmt.Errorf("short ACK packet: %d", l)
}
return pACK(p), nil
case opERROR:
if l < 5 {
return nil, fmt.Errorf("short ERROR packet: %d", l)
}
return pERROR(p), nil
case opOACK:
if l < 6 {
return nil, fmt.Errorf("short OACK packet: %d", l)
}
return pOACK(p), nil
default:
return nil, fmt.Errorf("unknown opcode: %d", opcode)
}
}

234
vendor/github.com/pin/tftp/receiver.go generated vendored Normal file
View File

@ -0,0 +1,234 @@
package tftp
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"time"
"github.com/pin/tftp/netascii"
)
// IncomingTransfer provides methods that expose information associated with
// an incoming transfer.
type IncomingTransfer interface {
// Size returns the size of an incoming file if the request included the
// tsize option (see RFC2349). To differentiate a zero-sized file transfer
// from a request without tsize use the second boolean "ok" return value.
Size() (n int64, ok bool)
// RemoteAddr returns the remote peer's IP address and port.
RemoteAddr() net.UDPAddr
}
func (r *receiver) RemoteAddr() net.UDPAddr { return *r.addr }
func (r *receiver) Size() (n int64, ok bool) {
if r.opts != nil {
if s, ok := r.opts["tsize"]; ok {
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, false
}
return n, true
}
}
return 0, false
}
type receiver struct {
send []byte
receive []byte
addr *net.UDPAddr
tid int
conn *net.UDPConn
block uint16
retry *backoff
timeout time.Duration
retries int
l int
autoTerm bool
dally bool
mode string
opts options
}
func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
if r.mode == "netascii" {
w = netascii.FromWriter(w)
}
if r.opts != nil {
err := r.sendOptions()
if err != nil {
r.abort(err)
return 0, err
}
}
binary.BigEndian.PutUint16(r.send[0:2], opACK)
for {
if r.l > 0 {
l, err := w.Write(r.receive[4:r.l])
n += int64(l)
if err != nil {
r.abort(err)
return n, err
}
if r.l < len(r.receive) {
if r.autoTerm {
r.terminate()
r.conn.Close()
}
return n, nil
}
}
binary.BigEndian.PutUint16(r.send[2:4], r.block)
r.block++ // send ACK for current block and expect next one
ll, _, err := r.receiveWithRetry(4)
if err != nil {
r.abort(err)
return n, err
}
r.l = ll
}
}
func (r *receiver) sendOptions() error {
for name, value := range r.opts {
if name == "blksize" {
err := r.setBlockSize(value)
if err != nil {
delete(r.opts, name)
continue
}
} else {
delete(r.opts, name)
}
}
if len(r.opts) > 0 {
m := packOACK(r.send, r.opts)
r.block = 1 // expect data block number 1
ll, _, err := r.receiveWithRetry(m)
if err != nil {
r.abort(err)
return err
}
r.l = ll
}
return nil
}
func (r *receiver) setBlockSize(blksize string) error {
n, err := strconv.Atoi(blksize)
if err != nil {
return err
}
if n < 512 {
return fmt.Errorf("blkzise too small: %d", n)
}
if n > 65464 {
return fmt.Errorf("blksize too large: %d", n)
}
r.receive = make([]byte, n+4)
return nil
}
func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) {
r.retry.reset()
for {
n, addr, err := r.receiveDatagram(l)
if _, ok := err.(net.Error); ok && r.retry.count() < r.retries {
r.retry.backoff()
continue
}
return n, addr, err
}
}
func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) {
err := r.conn.SetReadDeadline(time.Now().Add(r.timeout))
if err != nil {
return 0, nil, err
}
_, err = r.conn.WriteToUDP(r.send[:l], r.addr)
if err != nil {
return 0, nil, err
}
for {
c, addr, err := r.conn.ReadFromUDP(r.receive)
if err != nil {
return 0, nil, err
}
if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) {
continue
}
p, err := parsePacket(r.receive[:c])
if err != nil {
return 0, addr, err
}
r.tid = addr.Port
switch p := p.(type) {
case pDATA:
if p.block() == r.block {
return c, addr, nil
}
case pOACK:
opts, err := unpackOACK(p)
if r.block != 1 {
continue
}
if err != nil {
r.abort(err)
return 0, addr, err
}
for name, value := range opts {
if name == "blksize" {
err := r.setBlockSize(value)
if err != nil {
continue
}
}
}
r.block = 0 // ACK with block number 0
r.opts = opts
return 0, addr, nil
case pERROR:
return 0, addr, fmt.Errorf("code: %d, message: %s",
p.code(), p.message())
}
}
}
func (r *receiver) terminate() error {
binary.BigEndian.PutUint16(r.send[2:4], r.block)
if r.dally {
for i := 0; i < 3; i++ {
_, _, err := r.receiveDatagram(4)
if err != nil {
return nil
}
}
return fmt.Errorf("dallying termination failed")
} else {
_, err := r.conn.WriteToUDP(r.send[:4], r.addr)
if err != nil {
return err
}
}
return nil
}
func (r *receiver) abort(err error) error {
if r.conn == nil {
return nil
}
n := packERROR(r.send, 1, err.Error())
_, err = r.conn.WriteToUDP(r.send[:n], r.addr)
if err != nil {
return err
}
r.conn.Close()
r.conn = nil
return nil
}

243
vendor/github.com/pin/tftp/sender.go generated vendored Normal file
View File

@ -0,0 +1,243 @@
package tftp
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"time"
"github.com/pin/tftp/netascii"
)
// OutgoingTransfer provides methods to set the outgoing transfer size and
// retrieve the remote address of the peer.
type OutgoingTransfer interface {
// SetSize is used to set the outgoing transfer size (tsize option: RFC2349)
// manually in a server write transfer handler.
//
// It is not necessary in most cases; when the io.Reader provided to
// ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will
// be determined automatically. Seek will not be attempted when the
// transfer size option is set with SetSize.
//
// The value provided will be used only if SetSize is called before ReadFrom
// and only on in a server read handler.
SetSize(n int64)
// RemoteAddr returns the remote peer's IP address and port.
RemoteAddr() net.UDPAddr
}
type sender struct {
conn *net.UDPConn
addr *net.UDPAddr
tid int
send []byte
receive []byte
retry *backoff
timeout time.Duration
retries int
block uint16
mode string
opts options
}
func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr }
func (s *sender) SetSize(n int64) {
if s.opts != nil {
if _, ok := s.opts["tsize"]; ok {
s.opts["tsize"] = strconv.FormatInt(n, 10)
}
}
}
func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
if s.mode == "netascii" {
r = netascii.ToReader(r)
}
if s.opts != nil {
// check that tsize is set
if ts, ok := s.opts["tsize"]; ok {
// check that tsize is not set with SetSize already
i, err := strconv.ParseInt(ts, 10, 64)
if err == nil && i == 0 {
if rs, ok := r.(io.Seeker); ok {
pos, err := rs.Seek(0, 1)
if err != nil {
return 0, err
}
size, err := rs.Seek(0, 2)
if err != nil {
return 0, err
}
s.opts["tsize"] = strconv.FormatInt(size, 10)
_, err = rs.Seek(pos, 0)
if err != nil {
return 0, err
}
}
}
}
err = s.sendOptions()
if err != nil {
s.abort(err)
return 0, err
}
}
s.block = 1 // start data transmission with block 1
binary.BigEndian.PutUint16(s.send[0:2], opDATA)
for {
l, err := io.ReadFull(r, s.send[4:])
n += int64(l)
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
binary.BigEndian.PutUint16(s.send[2:4], s.block)
_, err = s.sendWithRetry(4)
if err != nil {
s.abort(err)
return n, err
}
s.conn.Close()
return n, nil
}
s.abort(err)
return n, err
}
binary.BigEndian.PutUint16(s.send[2:4], s.block)
_, err = s.sendWithRetry(4 + l)
if err != nil {
s.abort(err)
return n, err
}
if l < len(s.send)-4 {
s.conn.Close()
return n, nil
}
s.block++
}
}
func (s *sender) sendOptions() error {
for name, value := range s.opts {
if name == "blksize" {
err := s.setBlockSize(value)
if err != nil {
delete(s.opts, name)
continue
}
} else if name == "tsize" {
if value != "0" {
s.opts["tsize"] = value
} else {
delete(s.opts, name)
continue
}
} else {
delete(s.opts, name)
}
}
if len(s.opts) > 0 {
m := packOACK(s.send, s.opts)
_, err := s.sendWithRetry(m)
if err != nil {
return err
}
}
return nil
}
func (s *sender) setBlockSize(blksize string) error {
n, err := strconv.Atoi(blksize)
if err != nil {
return err
}
if n < 512 {
return fmt.Errorf("blkzise too small: %d", n)
}
if n > 65464 {
return fmt.Errorf("blksize too large: %d", n)
}
s.send = make([]byte, n+4)
return nil
}
func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) {
s.retry.reset()
for {
addr, err := s.sendDatagram(l)
if _, ok := err.(net.Error); ok && s.retry.count() < s.retries {
s.retry.backoff()
continue
}
return addr, err
}
}
func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
err := s.conn.SetReadDeadline(time.Now().Add(s.timeout))
if err != nil {
return nil, err
}
_, err = s.conn.WriteToUDP(s.send[:l], s.addr)
if err != nil {
return nil, err
}
for {
n, addr, err := s.conn.ReadFromUDP(s.receive)
if err != nil {
return nil, err
}
if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) {
continue
}
p, err := parsePacket(s.receive[:n])
if err != nil {
continue
}
s.tid = addr.Port
switch p := p.(type) {
case pACK:
if p.block() == s.block {
return addr, nil
}
case pOACK:
opts, err := unpackOACK(p)
if s.block != 0 {
continue
}
if err != nil {
s.abort(err)
return addr, err
}
for name, value := range opts {
if name == "blksize" {
err := s.setBlockSize(value)
if err != nil {
continue
}
}
}
return addr, nil
case pERROR:
return nil, fmt.Errorf("sending block %d: code=%d, error: %s",
s.block, p.code(), p.message())
}
}
}
func (s *sender) abort(err error) error {
if s.conn == nil {
return nil
}
n := packERROR(s.send, 1, err.Error())
_, err = s.conn.WriteToUDP(s.send[:n], s.addr)
if err != nil {
return err
}
s.conn.Close()
s.conn = nil
return nil
}

199
vendor/github.com/pin/tftp/server.go generated vendored Normal file
View File

@ -0,0 +1,199 @@
package tftp
import (
"fmt"
"io"
"net"
"sync"
"time"
)
// NewServer creates TFTP server. It requires two functions to handle
// read and write requests.
// In case nil is provided for read or write handler the respective
// operation is disabled.
func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
writeHandler func(filename string, wt io.WriterTo) error) *Server {
return &Server{
readHandler: readHandler,
writeHandler: writeHandler,
timeout: defaultTimeout,
retries: defaultRetries,
}
}
type Server struct {
readHandler func(filename string, rf io.ReaderFrom) error
writeHandler func(filename string, wt io.WriterTo) error
backoff backoffFunc
conn *net.UDPConn
quit chan chan struct{}
wg sync.WaitGroup
timeout time.Duration
retries int
}
// SetTimeout sets maximum time server waits for single network
// round-trip to succeed.
// Default is 5 seconds.
func (s *Server) SetTimeout(t time.Duration) {
if t <= 0 {
s.timeout = defaultTimeout
} else {
s.timeout = t
}
}
// SetRetries sets maximum number of attempts server made to transmit a
// packet.
// Default is 5 attempts.
func (s *Server) SetRetries(count int) {
if count < 1 {
s.retries = defaultRetries
} else {
s.retries = count
}
}
// SetBackoff sets a user provided function that is called to provide a
// backoff duration prior to retransmitting an unacknowledged packet.
func (s *Server) SetBackoff(h backoffFunc) {
s.backoff = h
}
// ListenAndServe binds to address provided and start the server.
// ListenAndServe returns when Shutdown is called.
func (s *Server) ListenAndServe(addr string) error {
a, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", a)
if err != nil {
return err
}
s.Serve(conn)
return nil
}
// Serve starts server provided already opened UDP connecton. It is
// useful for the case when you want to run server in separate goroutine
// but still want to be able to handle any errors opening connection.
// Serve returns when Shutdown is called or connection is closed.
func (s *Server) Serve(conn *net.UDPConn) {
s.conn = conn
s.quit = make(chan chan struct{})
for {
select {
case q := <-s.quit:
q <- struct{}{}
return
default:
err := s.processRequest(s.conn)
if err != nil {
// TODO: add logging handler
}
}
}
}
// Shutdown make server stop listening for new requests, allows
// server to finish outstanding transfers and stops server.
func (s *Server) Shutdown() {
s.conn.Close()
q := make(chan struct{})
s.quit <- q
<-q
s.wg.Wait()
}
func (s *Server) processRequest(conn *net.UDPConn) error {
var buffer []byte
buffer = make([]byte, datagramLength)
n, remoteAddr, err := conn.ReadFromUDP(buffer)
if err != nil {
return fmt.Errorf("reading UDP: %v", err)
}
p, err := parsePacket(buffer[:n])
if err != nil {
return err
}
switch p := p.(type) {
case pWRQ:
filename, mode, opts, err := unpackRQ(p)
if err != nil {
return fmt.Errorf("unpack WRQ: %v", err)
}
//fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return err
}
if err != nil {
return fmt.Errorf("open transmission: %v", err)
}
wt := &receiver{
send: make([]byte, datagramLength),
receive: make([]byte, datagramLength),
conn: conn,
retry: &backoff{handler: s.backoff},
timeout: s.timeout,
retries: s.retries,
addr: remoteAddr,
mode: mode,
opts: opts,
}
s.wg.Add(1)
go func() {
if s.writeHandler != nil {
err := s.writeHandler(filename, wt)
if err != nil {
wt.abort(err)
} else {
wt.terminate()
wt.conn.Close()
}
} else {
wt.abort(fmt.Errorf("server does not support write requests"))
}
s.wg.Done()
}()
case pRRQ:
filename, mode, opts, err := unpackRQ(p)
if err != nil {
return fmt.Errorf("unpack RRQ: %v", err)
}
//fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return err
}
rf := &sender{
send: make([]byte, datagramLength),
receive: make([]byte, datagramLength),
tid: remoteAddr.Port,
conn: conn,
retry: &backoff{handler: s.backoff},
timeout: s.timeout,
retries: s.retries,
addr: remoteAddr,
mode: mode,
opts: opts,
}
s.wg.Add(1)
go func() {
if s.readHandler != nil {
err := s.readHandler(filename, rf)
if err != nil {
rf.abort(err)
}
} else {
rf.abort(fmt.Errorf("server does not support read requests"))
}
s.wg.Done()
}()
default:
return fmt.Errorf("unexpected %T", p)
}
return nil
}

622
vendor/github.com/pin/tftp/tftp_test.go generated vendored Normal file
View File

@ -0,0 +1,622 @@
package tftp
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"os"
"strconv"
"sync"
"testing"
"testing/iotest"
"time"
)
var localhost string = determineLocalhost()
func determineLocalhost() string {
l, err := net.ListenTCP("tcp", nil)
if err != nil {
panic(fmt.Sprintf("ListenTCP error: %s", err))
}
_, lport, _ := net.SplitHostPort(l.Addr().String())
defer l.Close()
lo := make(chan string)
go func() {
for {
conn, err := l.Accept()
if err != nil {
break
}
conn.Close()
}
}()
go func() {
port, _ := strconv.Atoi(lport)
for _, af := range []string{"tcp6", "tcp4"} {
conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port})
if err == nil {
conn.Close()
host, _, _ := net.SplitHostPort(conn.LocalAddr().String())
lo <- host
return
}
}
panic("could not determine address family")
}()
return <-lo
}
func localSystem(c *net.UDPConn) string {
_, port, _ := net.SplitHostPort(c.LocalAddr().String())
return net.JoinHostPort(localhost, port)
}
func TestPackUnpack(t *testing.T) {
v := []string{"test-filename/with-subdir"}
testOptsList := []options{
nil,
options{
"tsize": "1234",
"blksize": "22",
},
}
for _, filename := range v {
for _, mode := range []string{"octet", "netascii"} {
for _, opts := range testOptsList {
packUnpack(t, filename, mode, opts)
}
}
}
}
func packUnpack(t *testing.T, filename, mode string, opts options) {
b := make([]byte, datagramLength)
for _, op := range []uint16{opRRQ, opWRQ} {
n := packRQ(b, op, filename, mode, opts)
f, m, o, err := unpackRQ(b[:n])
if err != nil {
t.Errorf("%s pack/unpack: %v", filename, err)
}
if f != filename {
t.Errorf("filename mismatch (%s): '%x' vs '%x'",
filename, f, filename)
}
if m != mode {
t.Errorf("mode mismatch (%s): '%x' vs '%x'",
mode, m, mode)
}
if opts != nil {
for name, value := range opts {
v, ok := o[name]
if !ok {
t.Errorf("missing %s option", name)
}
if v != value {
t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value)
}
}
}
}
}
func TestZeroLength(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
testSendReceive(t, c, 0)
}
func Test900(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
for i := 600; i < 4000; i += 1 {
c.blksize = i
testSendReceive(t, c, 9000+int64(i))
}
}
func Test1000(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
for i := int64(0); i < 5000; i++ {
filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano())
rf, err := c.Send(filename, "octet")
if err != nil {
t.Fatalf("requesting %s write: %v", filename, err)
}
r := io.LimitReader(newRandReader(rand.NewSource(i)), i)
n, err := rf.ReadFrom(r)
if err != nil {
t.Fatalf("sending %s: %v", filename, err)
}
if n != i {
t.Errorf("%s length mismatch: %d != %d", filename, n, i)
}
}
}
func Test1810(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
c.blksize = 1810
testSendReceive(t, c, 9000+1810)
}
func TestTSize(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
c.tsize = true
testSendReceive(t, c, 640)
}
func TestNearBlockLength(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
for i := 450; i < 520; i++ {
testSendReceive(t, c, int64(i))
}
}
func TestBlockWrapsAround(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
n := 65535 * 512
for i := n - 2; i < n+2; i++ {
testSendReceive(t, c, int64(i))
}
}
func TestRandomLength(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
r := rand.New(rand.NewSource(42))
for i := 0; i < 100; i++ {
testSendReceive(t, c, r.Int63n(100000))
}
}
func TestBigFile(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
testSendReceive(t, c, 3*1000*1000)
}
func TestByOneByte(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
filename := "test-by-one-byte"
mode := "octet"
const length = 80000
sender, err := c.Send(filename, mode)
if err != nil {
t.Fatalf("requesting write: %v", err)
}
r := iotest.OneByteReader(io.LimitReader(
newRandReader(rand.NewSource(42)), length))
n, err := sender.ReadFrom(r)
if err != nil {
t.Fatalf("send error: %v", err)
}
if n != length {
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
}
readTransfer, err := c.Receive(filename, mode)
if err != nil {
t.Fatalf("requesting read %s: %v", filename, err)
}
buf := &bytes.Buffer{}
n, err = readTransfer.WriteTo(buf)
if err != nil {
t.Fatalf("%s read error: %v", filename, err)
}
if n != length {
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
}
bs, _ := ioutil.ReadAll(io.LimitReader(
newRandReader(rand.NewSource(42)), length))
if !bytes.Equal(bs, buf.Bytes()) {
t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
}
}
func TestDuplicate(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
filename := "test-duplicate"
mode := "octet"
bs := []byte("lalala")
sender, err := c.Send(filename, mode)
if err != nil {
t.Fatalf("requesting write: %v", err)
}
buf := bytes.NewBuffer(bs)
_, err = sender.ReadFrom(buf)
if err != nil {
t.Fatalf("send error: %v", err)
}
sender, err = c.Send(filename, mode)
if err == nil {
t.Fatalf("file already exists")
}
t.Logf("sending file that already exists: %v", err)
}
func TestNotFound(t *testing.T) {
s, c := makeTestServer()
defer s.Shutdown()
filename := "test-not-exists"
mode := "octet"
_, err := c.Receive(filename, mode)
if err == nil {
t.Fatalf("file not exists", err)
}
t.Logf("receiving file that does not exist: %v", err)
}
func testSendReceive(t *testing.T, client *Client, length int64) {
filename := fmt.Sprintf("length-%d-bytes", length)
mode := "octet"
writeTransfer, err := client.Send(filename, mode)
if err != nil {
t.Fatalf("requesting write %s: %v", filename, err)
}
r := io.LimitReader(newRandReader(rand.NewSource(42)), length)
n, err := writeTransfer.ReadFrom(r)
if err != nil {
t.Fatalf("%s write error: %v", filename, err)
}
if n != length {
t.Errorf("%s write length mismatch: %d != %d", filename, n, length)
}
readTransfer, err := client.Receive(filename, mode)
if err != nil {
t.Fatalf("requesting read %s: %v", filename, err)
}
if it, ok := readTransfer.(IncomingTransfer); ok {
if n, ok := it.Size(); ok {
fmt.Printf("Transfer size: %d\n", n)
if n != length {
t.Errorf("tsize mismatch: %d vs %d", n, length)
}
}
}
buf := &bytes.Buffer{}
n, err = readTransfer.WriteTo(buf)
if err != nil {
t.Fatalf("%s read error: %v", filename, err)
}
if n != length {
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
}
bs, _ := ioutil.ReadAll(io.LimitReader(
newRandReader(rand.NewSource(42)), length))
if !bytes.Equal(bs, buf.Bytes()) {
t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
}
}
func TestSendTsizeFromSeek(t *testing.T) {
// create read-only server
s := NewServer(func(filename string, rf io.ReaderFrom) error {
b := make([]byte, 100)
rr := newRandReader(rand.NewSource(42))
rr.Read(b)
// bytes.Reader implements io.Seek
r := bytes.NewReader(b)
_, err := rf.ReadFrom(r)
if err != nil {
t.Errorf("sending bytes: %v", err)
}
return nil
}, nil)
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
t.Fatalf("listening: %v", err)
}
go s.Serve(conn)
defer s.Shutdown()
c, _ := NewClient(localSystem(conn))
c.tsize = true
r, _ := c.Receive("f", "octet")
var size int64
if t, ok := r.(IncomingTransfer); ok {
if n, ok := t.Size(); ok {
size = n
fmt.Printf("Transfer size: %d\n", n)
}
}
if size != 100 {
t.Errorf("size expected: 100, got %d", size)
}
r.WriteTo(ioutil.Discard)
}
type testBackend struct {
m map[string][]byte
mu sync.Mutex
}
func makeTestServer() (*Server, *Client) {
b := &testBackend{}
b.m = make(map[string][]byte)
// Create server
s := NewServer(b.handleRead, b.handleWrite)
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
panic(err)
}
go s.Serve(conn)
// Create client for that server
c, err := NewClient(localSystem(conn))
if err != nil {
panic(err)
}
return s, c
}
func TestNoHandlers(t *testing.T) {
s := NewServer(nil, nil)
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
panic(err)
}
go s.Serve(conn)
c, err := NewClient(localSystem(conn))
if err != nil {
panic(err)
}
_, err = c.Send("test", "octet")
if err == nil {
t.Errorf("error expected")
}
_, err = c.Receive("test", "octet")
if err == nil {
t.Errorf("error expected")
}
}
func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error {
b.mu.Lock()
defer b.mu.Unlock()
_, ok := b.m[filename]
if ok {
fmt.Fprintf(os.Stderr, "File %s already exists\n", filename)
return fmt.Errorf("file already exists")
}
if t, ok := wt.(IncomingTransfer); ok {
if n, ok := t.Size(); ok {
fmt.Printf("Transfer size: %d\n", n)
}
}
buf := &bytes.Buffer{}
_, err := wt.WriteTo(buf)
if err != nil {
fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err)
return err
}
b.m[filename] = buf.Bytes()
return nil
}
func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error {
b.mu.Lock()
defer b.mu.Unlock()
bs, ok := b.m[filename]
if !ok {
fmt.Fprintf(os.Stderr, "File %s not found\n", filename)
return fmt.Errorf("file not found")
}
if t, ok := rf.(OutgoingTransfer); ok {
t.SetSize(int64(len(bs)))
}
_, err := rf.ReadFrom(bytes.NewBuffer(bs))
if err != nil {
fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err)
return err
}
return nil
}
type randReader struct {
src rand.Source
next int64
i int8
}
func newRandReader(src rand.Source) io.Reader {
r := &randReader{
src: src,
next: src.Int63(),
}
return r
}
func (r *randReader) Read(p []byte) (n int, err error) {
next, i := r.next, r.i
for n = 0; n < len(p); n++ {
if i == 7 {
next, i = r.src.Int63(), 0
}
p[n] = byte(next)
next >>= 8
i++
}
r.next, r.i = next, i
return
}
func TestServerSendTimeout(t *testing.T) {
s, c := makeTestServer()
s.SetTimeout(time.Second)
s.SetRetries(2)
var serverErr error
s.readHandler = func(filename string, rf io.ReaderFrom) error {
r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
_, serverErr = rf.ReadFrom(r)
return serverErr
}
defer s.Shutdown()
filename := "test-server-send-timeout"
mode := "octet"
readTransfer, err := c.Receive(filename, mode)
if err != nil {
t.Fatalf("requesting read %s: %v", filename, err)
}
w := &slowWriter{
n: 3,
delay: 8 * time.Second,
}
_, _ = readTransfer.WriteTo(w)
netErr, ok := serverErr.(net.Error)
if !ok {
t.Fatalf("network error expected: %T", serverErr)
}
if !netErr.Timeout() {
t.Fatalf("timout is expected: %v", serverErr)
}
}
func TestServerReceiveTimeout(t *testing.T) {
s, c := makeTestServer()
s.SetTimeout(time.Second)
s.SetRetries(2)
var serverErr error
s.writeHandler = func(filename string, wt io.WriterTo) error {
buf := &bytes.Buffer{}
_, serverErr = wt.WriteTo(buf)
return serverErr
}
defer s.Shutdown()
filename := "test-server-receive-timeout"
mode := "octet"
writeTransfer, err := c.Send(filename, mode)
if err != nil {
t.Fatalf("requesting write %s: %v", filename, err)
}
r := &slowReader{
r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
n: 3,
delay: 8 * time.Second,
}
_, _ = writeTransfer.ReadFrom(r)
netErr, ok := serverErr.(net.Error)
if !ok {
t.Fatalf("network error expected: %T", serverErr)
}
if !netErr.Timeout() {
t.Fatalf("timout is expected: %v", serverErr)
}
}
func TestClientReceiveTimeout(t *testing.T) {
s, c := makeTestServer()
c.SetTimeout(time.Second)
c.SetRetries(2)
s.readHandler = func(filename string, rf io.ReaderFrom) error {
r := &slowReader{
r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
n: 3,
delay: 8 * time.Second,
}
_, err := rf.ReadFrom(r)
return err
}
defer s.Shutdown()
filename := "test-client-receive-timeout"
mode := "octet"
readTransfer, err := c.Receive(filename, mode)
if err != nil {
t.Fatalf("requesting read %s: %v", filename, err)
}
buf := &bytes.Buffer{}
_, err = readTransfer.WriteTo(buf)
netErr, ok := err.(net.Error)
if !ok {
t.Fatalf("network error expected: %T", err)
}
if !netErr.Timeout() {
t.Fatalf("timout is expected: %v", err)
}
}
func TestClientSendTimeout(t *testing.T) {
s, c := makeTestServer()
c.SetTimeout(time.Second)
c.SetRetries(2)
s.writeHandler = func(filename string, wt io.WriterTo) error {
w := &slowWriter{
n: 3,
delay: 8 * time.Second,
}
_, err := wt.WriteTo(w)
return err
}
defer s.Shutdown()
filename := "test-client-send-timeout"
mode := "octet"
writeTransfer, err := c.Send(filename, mode)
if err != nil {
t.Fatalf("requesting write %s: %v", filename, err)
}
r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
_, err = writeTransfer.ReadFrom(r)
netErr, ok := err.(net.Error)
if !ok {
t.Fatalf("network error expected: %T", err)
}
if !netErr.Timeout() {
t.Fatalf("timout is expected: %v", err)
}
}
type slowReader struct {
r io.Reader
n int64
delay time.Duration
}
func (r *slowReader) Read(p []byte) (n int, err error) {
if r.n > 0 {
r.n--
return r.r.Read(p)
}
time.Sleep(r.delay)
return r.r.Read(p)
}
type slowWriter struct {
r io.Reader
n int64
delay time.Duration
}
func (r *slowWriter) Write(p []byte) (n int, err error) {
if r.n > 0 {
r.n--
return len(p), nil
}
time.Sleep(r.delay)
return len(p), nil
}