mirror of
https://github.com/rancher/os.git
synced 2025-07-15 23:55:51 +00:00
Added tftp datasource for cloud config.
This commit is contained in:
parent
8b75752225
commit
66c5f6130a
@ -38,6 +38,7 @@ import (
|
||||
"github.com/rancher/os/config/cloudinit/datasource/metadata/gce"
|
||||
"github.com/rancher/os/config/cloudinit/datasource/metadata/packet"
|
||||
"github.com/rancher/os/config/cloudinit/datasource/proccmdline"
|
||||
"github.com/rancher/os/config/cloudinit/datasource/tftp"
|
||||
"github.com/rancher/os/config/cloudinit/datasource/url"
|
||||
"github.com/rancher/os/config/cloudinit/datasource/vmware"
|
||||
"github.com/rancher/os/config/cloudinit/pkg"
|
||||
@ -237,6 +238,8 @@ func getDatasources(datasources []string) []datasource.Datasource {
|
||||
if root != "" {
|
||||
dss = append(dss, file.NewDatasource(root))
|
||||
}
|
||||
case "tftp":
|
||||
dss = append(dss, tftp.NewDatasource(root))
|
||||
case "url":
|
||||
if root != "" {
|
||||
dss = append(dss, url.NewDatasource(root))
|
||||
|
83
config/cloudinit/datasource/tftp/tftp.go
Normal file
83
config/cloudinit/datasource/tftp/tftp.go
Normal file
@ -0,0 +1,83 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/rancher/os/config/cloudinit/datasource"
|
||||
|
||||
"github.com/pin/tftp"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
Receive(filename string, mode string) (io.WriterTo, error)
|
||||
}
|
||||
|
||||
type RemoteFile struct {
|
||||
host string
|
||||
path string
|
||||
client Client
|
||||
stream io.WriterTo
|
||||
lastError error
|
||||
}
|
||||
|
||||
func NewDatasource(hostAndPath string) *RemoteFile {
|
||||
parts := strings.SplitN(hostAndPath, "/", 2)
|
||||
|
||||
if len(parts) < 2 {
|
||||
return &RemoteFile{hostAndPath, "", nil, nil, nil}
|
||||
}
|
||||
|
||||
host := parts[0]
|
||||
if match, _ := regexp.MatchString(":[0-9]{2,5}$", host); !match {
|
||||
// No port, using default port 69
|
||||
host += ":69"
|
||||
}
|
||||
|
||||
path := parts[1]
|
||||
if client, lastError := tftp.NewClient(host); lastError == nil {
|
||||
return &RemoteFile{host, path, client, nil, nil}
|
||||
}
|
||||
|
||||
return &RemoteFile{host, path, nil, nil, nil}
|
||||
}
|
||||
|
||||
func (f *RemoteFile) IsAvailable() bool {
|
||||
f.stream, f.lastError = f.client.Receive(f.path, "octet")
|
||||
return f.lastError == nil
|
||||
}
|
||||
|
||||
func (f *RemoteFile) Finish() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *RemoteFile) String() string {
|
||||
return fmt.Sprintf("%s: %s%s (lastError: %v)", f.Type(), f.host, f.path, f.lastError)
|
||||
}
|
||||
|
||||
func (f *RemoteFile) AvailabilityChanges() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *RemoteFile) ConfigRoot() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (f *RemoteFile) FetchMetadata() (datasource.Metadata, error) {
|
||||
return datasource.Metadata{}, nil
|
||||
}
|
||||
|
||||
func (f *RemoteFile) FetchUserdata() ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
|
||||
_, err := f.stream.WriteTo(&b)
|
||||
|
||||
return b.Bytes(), err
|
||||
}
|
||||
|
||||
func (f *RemoteFile) Type() string {
|
||||
return "tftp"
|
||||
}
|
92
config/cloudinit/datasource/tftp/tftp_test.go
Normal file
92
config/cloudinit/datasource/tftp/tftp_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
}
|
||||
|
||||
type mockReceiver struct {
|
||||
}
|
||||
|
||||
func (r mockReceiver) WriteTo(w io.Writer) (n int64, err error) {
|
||||
b := []byte("cloud-config file")
|
||||
w.Write(b)
|
||||
return int64(len(b)), nil
|
||||
}
|
||||
|
||||
func (c mockClient) Receive(filename string, mode string) (io.WriterTo, error) {
|
||||
if filename == "does-not-exist" {
|
||||
return &mockReceiver{}, fmt.Errorf("does not exist")
|
||||
}
|
||||
return &mockReceiver{}, nil
|
||||
}
|
||||
|
||||
var _ Client = (*mockClient)(nil)
|
||||
|
||||
func TestNewDatasource(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
root string
|
||||
expectHost string
|
||||
expectPath string
|
||||
}{
|
||||
{
|
||||
root: "127.0.0.1/test/file.yaml",
|
||||
expectHost: "127.0.0.1:69",
|
||||
expectPath: "test/file.yaml",
|
||||
},
|
||||
{
|
||||
root: "127.0.0.1/test/file.yaml",
|
||||
expectHost: "127.0.0.1:69",
|
||||
expectPath: "test/file.yaml",
|
||||
},
|
||||
} {
|
||||
ds := NewDatasource(tt.root)
|
||||
if ds.host != tt.expectHost || ds.path != tt.expectPath {
|
||||
t.Fatalf("bad host or path (%q): want host=%s, got %s, path=%s, got %s", tt.root, tt.expectHost, ds.host, tt.expectPath, ds.path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAvailable(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
remoteFile *RemoteFile
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
remoteFile: &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
remoteFile: &RemoteFile{"1.2.3.4", "does-not-exist", &mockClient{}, nil, nil},
|
||||
expect: false,
|
||||
},
|
||||
} {
|
||||
if tt.remoteFile.IsAvailable() != tt.expect {
|
||||
t.Fatalf("expected remote file %s to be %v", tt.remoteFile.path, tt.expect)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchUserdata(t *testing.T) {
|
||||
rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, &mockReceiver{}, nil}
|
||||
b, _ := rf.FetchUserdata()
|
||||
|
||||
expect := []byte("cloud-config file")
|
||||
|
||||
if len(b) != len(expect) || !reflect.DeepEqual(b, expect) {
|
||||
t.Fatalf("expected length of buffer to be %d was %d. Expected %s, got %s", len(expect), len(b), string(expect), string(b))
|
||||
}
|
||||
}
|
||||
|
||||
func TestType(t *testing.T) {
|
||||
rf := &RemoteFile{"1.2.3.4", "test", &mockClient{}, nil, nil}
|
||||
|
||||
if rf.Type() != "tftp" {
|
||||
t.Fatalf("expected remote file Type() to return %s got %s", "tftp", rf.Type())
|
||||
}
|
||||
}
|
@ -57,3 +57,4 @@ golang.org/x/sys eb2c74142fd19a79b3f237334c7384d5167b1b46 https://github.com/gol
|
||||
google.golang.org/grpc ab0be5212fb225475f2087566eded7da5d727960 https://github.com/grpc/grpc-go.git
|
||||
gopkg.in/fsnotify.v1 v1.2.0
|
||||
github.com/fatih/structs dc3312cb1a4513a366c4c9e622ad55c32df12ed3
|
||||
github.com/pin/tftp v2.1.0
|
||||
|
24
vendor/github.com/pin/tftp/.gitignore
generated
vendored
Normal file
24
vendor/github.com/pin/tftp/.gitignore
generated
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
*.test
|
||||
*.prof
|
8
vendor/github.com/pin/tftp/.travis.yml
generated
vendored
Normal file
8
vendor/github.com/pin/tftp/.travis.yml
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
language: go
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
before_install:
|
||||
- ulimit -n 4096
|
4
vendor/github.com/pin/tftp/CONTRIBUTORS
generated
vendored
Normal file
4
vendor/github.com/pin/tftp/CONTRIBUTORS
generated
vendored
Normal file
@ -0,0 +1,4 @@
|
||||
Dmitri Popov
|
||||
Mojo Talantikite
|
||||
Giovanni Bajo
|
||||
Andrew Danforth
|
21
vendor/github.com/pin/tftp/LICENSE
generated
vendored
Normal file
21
vendor/github.com/pin/tftp/LICENSE
generated
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
Copyright (c) 2016 Dmitri Popov
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
171
vendor/github.com/pin/tftp/README.md
generated
vendored
Normal file
171
vendor/github.com/pin/tftp/README.md
generated
vendored
Normal file
@ -0,0 +1,171 @@
|
||||
TFTP server and client library for Golang
|
||||
=========================================
|
||||
|
||||
[](https://godoc.org/github.com/pin/tftp)
|
||||
[](https://travis-ci.org/pin/tftp)
|
||||
|
||||
Implements:
|
||||
* [RFC 1350](https://tools.ietf.org/html/rfc1350) - The TFTP Protocol (Revision 2)
|
||||
* [RFC 2347](https://tools.ietf.org/html/rfc2347) - TFTP Option Extension
|
||||
* [RFC 2348](https://tools.ietf.org/html/rfc2348) - TFTP Blocksize Option
|
||||
|
||||
Partially implements (tsize server side only):
|
||||
* [RFC 2349](https://tools.ietf.org/html/rfc2349) - TFTP Timeout Interval and Transfer Size Options
|
||||
|
||||
Set of features is sufficient for PXE boot support.
|
||||
|
||||
``` go
|
||||
import "github.com/pin/tftp"
|
||||
```
|
||||
|
||||
The package is cohesive to Golang `io`. Particularly it implements
|
||||
`io.ReaderFrom` and `io.WriterTo` interfaces. That allows efficient data
|
||||
transmission without unnecessary memory copying and allocations.
|
||||
|
||||
|
||||
TFTP Server
|
||||
-----------
|
||||
|
||||
```go
|
||||
|
||||
// readHandler is called when client starts file download from server
|
||||
func readHandler(filename string, rf io.ReaderFrom) error {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return err
|
||||
}
|
||||
n, err := rf.ReadFrom(file)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%d bytes sent\n", n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeHandler is called when client starts file upload to server
|
||||
func writeHandler(filename string, wt io.WriterTo) error {
|
||||
file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return err
|
||||
}
|
||||
n, err := wt.WriteTo(file)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%d bytes received\n", n)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
// use nil in place of handler to disable read or write operations
|
||||
s := tftp.NewServer(readHandler, writeHandler)
|
||||
s.SetTimeout(5 * time.Second) // optional
|
||||
err := s.ListenAndServe(":69") // blocks until s.Shutdown() is called
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stdout, "server: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
TFTP Client
|
||||
-----------
|
||||
Upload file to server:
|
||||
|
||||
```go
|
||||
c, err := tftp.NewClient("172.16.4.21:69")
|
||||
file, err := os.Open(path)
|
||||
c.SetTimeout(5 * time.Second) // optional
|
||||
rf, err := c.Send("foobar.txt", "octet")
|
||||
n, err := rf.ReadFrom(file)
|
||||
fmt.Printf("%d bytes sent\n", n)
|
||||
```
|
||||
|
||||
Download file from server:
|
||||
|
||||
```go
|
||||
c, err := tftp.NewClient("172.16.4.21:69")
|
||||
wt, err := c.Receive("foobar.txt", "octet")
|
||||
file, err := os.Create(path)
|
||||
// Optionally obtain transfer size before actual data.
|
||||
if n, ok := wt.(IncomingTransfer).Size(); ok {
|
||||
fmt.Printf("Transfer size: %d\n", n)
|
||||
}
|
||||
n, err := wt.WriteTo(file)
|
||||
fmt.Printf("%d bytes received\n", n)
|
||||
```
|
||||
|
||||
Note: please handle errors better :)
|
||||
|
||||
TSize option
|
||||
------------
|
||||
|
||||
PXE boot ROM often expects tsize option support from a server: client
|
||||
(e.g. computer that boots over the network) wants to know size of a
|
||||
download before the actual data comes. Server has to obtain stream
|
||||
size and send it to a client.
|
||||
|
||||
Often it will happen automatically because TFTP library tries to check
|
||||
if `io.Reader` provided to `ReadFrom` method also satisfies
|
||||
`io.Seeker` interface (`os.File` for instance) and uses `Seek` to
|
||||
determine file size.
|
||||
|
||||
In case `io.Reader` you provide to `ReadFrom` in read handler does not
|
||||
satisfy `io.Seeker` interface or you do not want TFTP library to call
|
||||
`Seek` on your reader but still want to respond with tsize option
|
||||
during outgoing request you can use an `OutgoingTransfer` interface:
|
||||
|
||||
```go
|
||||
|
||||
func readHandler(filename string, rf io.ReaderFrom) error {
|
||||
...
|
||||
// Set transfer size before calling ReadFrom.
|
||||
rf.(tftp.OutgoingTransfer).SetSize(myFileSize)
|
||||
...
|
||||
// ReadFrom ...
|
||||
|
||||
```
|
||||
|
||||
Similarly, it is possible to obtain size of a file that is about to be
|
||||
received using `IncomingTransfer` interface (see `Size` method).
|
||||
|
||||
Remote Address
|
||||
--------------
|
||||
|
||||
The `OutgoingTransfer` and `IncomingTransfer` interfaces also provide the
|
||||
`RemoteAddr` method which returns the peer IP address and port as a
|
||||
`net.UDPAddr`. This can be used for detailed logging in a server handler.
|
||||
|
||||
```go
|
||||
|
||||
func readHandler(filename string, rf io.ReaderFrom) error {
|
||||
...
|
||||
raddr := rf.(tftp.OutgoingTransfer).RemoteAddr()
|
||||
log.Println("RRQ from", raddr.String())
|
||||
...
|
||||
// ReadFrom ...
|
||||
```
|
||||
|
||||
Backoff
|
||||
-------
|
||||
|
||||
The default backoff before retransmitting an unacknowledged packet is a
|
||||
random duration between 0 and 1 second. This behavior can be overridden
|
||||
in clients and servers by providing a custom backoff calculation function.
|
||||
|
||||
```go
|
||||
s := tftp.NewServer(readHandler, writeHandler)
|
||||
s.SetBackoff(func (attempts int) time.Duration {
|
||||
return time.Duration(attempts) * time.Second
|
||||
})
|
||||
```
|
||||
|
||||
or, for no backoff
|
||||
|
||||
```go
|
||||
s.SetBackoff(func (int) time.Duration { return 0 })
|
||||
```
|
35
vendor/github.com/pin/tftp/backoff.go
generated
vendored
Normal file
35
vendor/github.com/pin/tftp/backoff.go
generated
vendored
Normal file
@ -0,0 +1,35 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
defaultRetries = 5
|
||||
)
|
||||
|
||||
type backoffFunc func(int) time.Duration
|
||||
|
||||
type backoff struct {
|
||||
attempt int
|
||||
handler backoffFunc
|
||||
}
|
||||
|
||||
func (b *backoff) reset() {
|
||||
b.attempt = 0
|
||||
}
|
||||
|
||||
func (b *backoff) count() int {
|
||||
return b.attempt
|
||||
}
|
||||
|
||||
func (b *backoff) backoff() {
|
||||
if b.handler == nil {
|
||||
time.Sleep(time.Duration(rand.Int63n(int64(time.Second))))
|
||||
} else {
|
||||
time.Sleep(b.handler(b.attempt))
|
||||
}
|
||||
b.attempt++
|
||||
}
|
125
vendor/github.com/pin/tftp/client.go
generated
vendored
Normal file
125
vendor/github.com/pin/tftp/client.go
generated
vendored
Normal file
@ -0,0 +1,125 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewClient creates TFTP client for server on address provided.
|
||||
func NewClient(addr string) (*Client, error) {
|
||||
a, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving address %s: %v", addr, err)
|
||||
}
|
||||
return &Client{
|
||||
addr: a,
|
||||
timeout: defaultTimeout,
|
||||
retries: defaultRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetTimeout sets maximum time client waits for single network round-trip to succeed.
|
||||
// Default is 5 seconds.
|
||||
func (c *Client) SetTimeout(t time.Duration) {
|
||||
if t <= 0 {
|
||||
c.timeout = defaultTimeout
|
||||
}
|
||||
c.timeout = t
|
||||
}
|
||||
|
||||
// SetRetries sets maximum number of attempts client made to transmit a packet.
|
||||
// Default is 5 attempts.
|
||||
func (c *Client) SetRetries(count int) {
|
||||
if count < 1 {
|
||||
c.retries = defaultRetries
|
||||
}
|
||||
c.retries = count
|
||||
}
|
||||
|
||||
// SetBackoff sets a user provided function that is called to provide a
|
||||
// backoff duration prior to retransmitting an unacknowledged packet.
|
||||
func (c *Client) SetBackoff(h backoffFunc) {
|
||||
c.backoff = h
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
addr *net.UDPAddr
|
||||
timeout time.Duration
|
||||
retries int
|
||||
backoff backoffFunc
|
||||
blksize int
|
||||
tsize bool
|
||||
}
|
||||
|
||||
// Send starts outgoing file transmission. It returns io.ReaderFrom or error.
|
||||
func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &sender{
|
||||
send: make([]byte, datagramLength),
|
||||
receive: make([]byte, datagramLength),
|
||||
conn: conn,
|
||||
retry: &backoff{handler: c.backoff},
|
||||
timeout: c.timeout,
|
||||
retries: c.retries,
|
||||
addr: c.addr,
|
||||
mode: mode,
|
||||
}
|
||||
if c.blksize != 0 {
|
||||
s.opts = make(options)
|
||||
s.opts["blksize"] = strconv.Itoa(c.blksize)
|
||||
}
|
||||
n := packRQ(s.send, opWRQ, filename, mode, s.opts)
|
||||
addr, err := s.sendWithRetry(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.addr = addr
|
||||
s.opts = nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Receive starts incoming file transmission. It returns io.WriterTo or error.
|
||||
func (c Client) Receive(filename string, mode string) (io.WriterTo, error) {
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.timeout == 0 {
|
||||
c.timeout = defaultTimeout
|
||||
}
|
||||
r := &receiver{
|
||||
send: make([]byte, datagramLength),
|
||||
receive: make([]byte, datagramLength),
|
||||
conn: conn,
|
||||
retry: &backoff{handler: c.backoff},
|
||||
timeout: c.timeout,
|
||||
retries: c.retries,
|
||||
addr: c.addr,
|
||||
autoTerm: true,
|
||||
block: 1,
|
||||
mode: mode,
|
||||
}
|
||||
if c.blksize != 0 || c.tsize {
|
||||
r.opts = make(options)
|
||||
}
|
||||
if c.blksize != 0 {
|
||||
r.opts["blksize"] = strconv.Itoa(c.blksize)
|
||||
}
|
||||
if c.tsize {
|
||||
r.opts["tsize"] = "0"
|
||||
}
|
||||
n := packRQ(r.send, opRRQ, filename, mode, r.opts)
|
||||
l, addr, err := r.receiveWithRetry(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.l = l
|
||||
r.addr = addr
|
||||
return r, nil
|
||||
}
|
108
vendor/github.com/pin/tftp/netascii/netascii.go
generated
vendored
Normal file
108
vendor/github.com/pin/tftp/netascii/netascii.go
generated
vendored
Normal file
@ -0,0 +1,108 @@
|
||||
package netascii
|
||||
|
||||
// TODO: make it work not only on linux
|
||||
|
||||
import "io"
|
||||
|
||||
const (
|
||||
CR = '\x0d'
|
||||
LF = '\x0a'
|
||||
NUL = '\x00'
|
||||
)
|
||||
|
||||
func ToReader(r io.Reader) io.Reader {
|
||||
return &toReader{
|
||||
r: r,
|
||||
buf: make([]byte, 256),
|
||||
}
|
||||
}
|
||||
|
||||
type toReader struct {
|
||||
r io.Reader
|
||||
buf []byte
|
||||
n int
|
||||
i int
|
||||
err error
|
||||
lf bool
|
||||
nul bool
|
||||
}
|
||||
|
||||
func (r *toReader) Read(p []byte) (int, error) {
|
||||
var n int
|
||||
for n < len(p) {
|
||||
if r.lf {
|
||||
p[n] = LF
|
||||
n++
|
||||
r.lf = false
|
||||
continue
|
||||
}
|
||||
if r.nul {
|
||||
p[n] = NUL
|
||||
n++
|
||||
r.nul = false
|
||||
continue
|
||||
}
|
||||
if r.i < r.n {
|
||||
if r.buf[r.i] == LF {
|
||||
p[n] = CR
|
||||
r.lf = true
|
||||
} else if r.buf[r.i] == CR {
|
||||
p[n] = CR
|
||||
r.nul = true
|
||||
|
||||
} else {
|
||||
p[n] = r.buf[r.i]
|
||||
}
|
||||
r.i++
|
||||
n++
|
||||
continue
|
||||
}
|
||||
if r.err == nil {
|
||||
r.n, r.err = r.r.Read(r.buf)
|
||||
r.i = 0
|
||||
} else {
|
||||
return n, r.err
|
||||
}
|
||||
}
|
||||
return n, r.err
|
||||
}
|
||||
|
||||
type fromWriter struct {
|
||||
w io.Writer
|
||||
buf []byte
|
||||
i int
|
||||
cr bool
|
||||
}
|
||||
|
||||
func FromWriter(w io.Writer) io.Writer {
|
||||
return &fromWriter{
|
||||
w: w,
|
||||
buf: make([]byte, 256),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *fromWriter) Write(p []byte) (n int, err error) {
|
||||
for n < len(p) {
|
||||
if w.cr {
|
||||
if p[n] == LF {
|
||||
w.buf[w.i] = LF
|
||||
}
|
||||
if p[n] == NUL {
|
||||
w.buf[w.i] = CR
|
||||
}
|
||||
w.cr = false
|
||||
w.i++
|
||||
} else if p[n] == CR {
|
||||
w.cr = true
|
||||
} else {
|
||||
w.buf[w.i] = p[n]
|
||||
w.i++
|
||||
}
|
||||
n++
|
||||
if w.i == len(w.buf) || n == len(p) {
|
||||
_, err = w.w.Write(w.buf[:w.i])
|
||||
w.i = 0
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
89
vendor/github.com/pin/tftp/netascii/netascii_test.go
generated
vendored
Normal file
89
vendor/github.com/pin/tftp/netascii/netascii_test.go
generated
vendored
Normal file
@ -0,0 +1,89 @@
|
||||
package netascii
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
)
|
||||
|
||||
var basic = map[string]string{
|
||||
"\r": "\r\x00",
|
||||
"\n": "\r\n",
|
||||
"la\nbu": "la\r\nbu",
|
||||
"la\rbu": "la\r\x00bu",
|
||||
"\r\r\r": "\r\x00\r\x00\r\x00",
|
||||
"\n\n\n": "\r\n\r\n\r\n",
|
||||
}
|
||||
|
||||
func TestTo(t *testing.T) {
|
||||
for text, netascii := range basic {
|
||||
to := ToReader(strings.NewReader(text))
|
||||
n, _ := ioutil.ReadAll(to)
|
||||
if bytes.Compare(n, []byte(netascii)) != 0 {
|
||||
t.Errorf("%q to netascii: %q != %q", text, n, netascii)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrom(t *testing.T) {
|
||||
for text, netascii := range basic {
|
||||
r := bytes.NewReader([]byte(netascii))
|
||||
b := &bytes.Buffer{}
|
||||
from := FromWriter(b)
|
||||
r.WriteTo(from)
|
||||
n, _ := ioutil.ReadAll(b)
|
||||
if string(n) != text {
|
||||
t.Errorf("%q from netascii: %q != %q", netascii, n, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const text = `
|
||||
Therefore, the sequence "CR LF" must be treated as a single "new
|
||||
line" character and used whenever their combined action is
|
||||
intended; the sequence "CR NUL" must be used where a carriage
|
||||
return alone is actually desired; and the CR character must be
|
||||
avoided in other contexts. This rule gives assurance to systems
|
||||
which must decide whether to perform a "new line" function or a
|
||||
multiple-backspace that the TELNET stream contains a character
|
||||
following a CR that will allow a rational decision.
|
||||
(in the default ASCII mode), to preserve the symmetry of the
|
||||
NVT model. Even though it may be known in some situations
|
||||
(e.g., with remote echo and suppress go ahead options in
|
||||
effect) that characters are not being sent to an actual
|
||||
printer, nonetheless, for the sake of consistency, the protocol
|
||||
requires that a NUL be inserted following a CR not followed by
|
||||
a LF in the data stream. The converse of this is that a NUL
|
||||
received in the data stream after a CR (in the absence of
|
||||
options negotiations which explicitly specify otherwise) should
|
||||
be stripped out prior to applying the NVT to local character
|
||||
set mapping.
|
||||
`
|
||||
|
||||
func TestWriteRead(t *testing.T) {
|
||||
var one bytes.Buffer
|
||||
to := ToReader(strings.NewReader(text))
|
||||
one.ReadFrom(to)
|
||||
two := &bytes.Buffer{}
|
||||
from := FromWriter(two)
|
||||
one.WriteTo(from)
|
||||
text2, _ := ioutil.ReadAll(two)
|
||||
if text != string(text2) {
|
||||
t.Errorf("text mismatch \n%x \n%x", text, text2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOneByte(t *testing.T) {
|
||||
var one bytes.Buffer
|
||||
to := iotest.OneByteReader(ToReader(strings.NewReader(text)))
|
||||
one.ReadFrom(to)
|
||||
two := &bytes.Buffer{}
|
||||
from := FromWriter(two)
|
||||
one.WriteTo(from)
|
||||
text2, _ := ioutil.ReadAll(two)
|
||||
if text != string(text2) {
|
||||
t.Errorf("text mismatch \n%x \n%x", text, text2)
|
||||
}
|
||||
}
|
190
vendor/github.com/pin/tftp/packet.go
generated
vendored
Normal file
190
vendor/github.com/pin/tftp/packet.go
generated
vendored
Normal file
@ -0,0 +1,190 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
opRRQ = uint16(1) // Read request (RRQ)
|
||||
opWRQ = uint16(2) // Write request (WRQ)
|
||||
opDATA = uint16(3) // Data
|
||||
opACK = uint16(4) // Acknowledgement
|
||||
opERROR = uint16(5) // Error
|
||||
opOACK = uint16(6) // Options Acknowledgment
|
||||
)
|
||||
|
||||
const (
|
||||
blockLength = 512
|
||||
datagramLength = 516
|
||||
)
|
||||
|
||||
type options map[string]string
|
||||
|
||||
// RRQ/WRQ packet
|
||||
//
|
||||
// 2 bytes string 1 byte string 1 byte
|
||||
// --------------------------------------------------
|
||||
// | Opcode | Filename | 0 | Mode | 0 |
|
||||
// --------------------------------------------------
|
||||
type pRRQ []byte
|
||||
type pWRQ []byte
|
||||
|
||||
// packRQ returns length of the packet in b
|
||||
func packRQ(p []byte, op uint16, filename, mode string, opts options) int {
|
||||
binary.BigEndian.PutUint16(p, op)
|
||||
n := 2
|
||||
n += copy(p[2:len(p)-10], filename)
|
||||
p[n] = 0
|
||||
n++
|
||||
n += copy(p[n:], mode)
|
||||
p[n] = 0
|
||||
n++
|
||||
for name, value := range opts {
|
||||
n += copy(p[n:], name)
|
||||
p[n] = 0
|
||||
n++
|
||||
n += copy(p[n:], value)
|
||||
p[n] = 0
|
||||
n++
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func unpackRQ(p []byte) (filename, mode string, opts options, err error) {
|
||||
bs := bytes.Split(p[2:], []byte{0})
|
||||
if len(bs) < 2 {
|
||||
return "", "", nil, fmt.Errorf("missing filename or mode")
|
||||
}
|
||||
filename = string(bs[0])
|
||||
mode = string(bs[1])
|
||||
if len(bs) < 4 {
|
||||
return filename, mode, nil, nil
|
||||
}
|
||||
opts = make(options)
|
||||
for i := 2; i+1 < len(bs); i += 2 {
|
||||
opts[string(bs[i])] = string(bs[i+1])
|
||||
}
|
||||
return filename, mode, opts, nil
|
||||
}
|
||||
|
||||
// OACK packet
|
||||
//
|
||||
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
|
||||
// | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
|
||||
// +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
|
||||
type pOACK []byte
|
||||
|
||||
func packOACK(p []byte, opts options) int {
|
||||
binary.BigEndian.PutUint16(p, opOACK)
|
||||
n := 2
|
||||
for name, value := range opts {
|
||||
n += copy(p[n:], name)
|
||||
p[n] = 0
|
||||
n++
|
||||
n += copy(p[n:], value)
|
||||
p[n] = 0
|
||||
n++
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func unpackOACK(p []byte) (opts options, err error) {
|
||||
bs := bytes.Split(p[2:], []byte{0})
|
||||
opts = make(options)
|
||||
for i := 0; i+1 < len(bs); i += 2 {
|
||||
opts[string(bs[i])] = string(bs[i+1])
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// ERROR packet
|
||||
//
|
||||
// 2 bytes 2 bytes string 1 byte
|
||||
// ------------------------------------------
|
||||
// | Opcode | ErrorCode | ErrMsg | 0 |
|
||||
// ------------------------------------------
|
||||
type pERROR []byte
|
||||
|
||||
func packERROR(p []byte, code uint16, message string) int {
|
||||
binary.BigEndian.PutUint16(p, opERROR)
|
||||
binary.BigEndian.PutUint16(p[2:], code)
|
||||
n := copy(p[4:len(p)-2], message)
|
||||
p[4+n] = 0
|
||||
return n + 5
|
||||
}
|
||||
|
||||
func (p pERROR) code() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func (p pERROR) message() string {
|
||||
return string(p[4:])
|
||||
}
|
||||
|
||||
// DATA packet
|
||||
//
|
||||
// 2 bytes 2 bytes n bytes
|
||||
// ----------------------------------
|
||||
// | Opcode | Block # | Data |
|
||||
// ----------------------------------
|
||||
type pDATA []byte
|
||||
|
||||
func (p pDATA) block() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
// ACK packet
|
||||
//
|
||||
// 2 bytes 2 bytes
|
||||
// -----------------------
|
||||
// | Opcode | Block # |
|
||||
// -----------------------
|
||||
type pACK []byte
|
||||
|
||||
func (p pACK) block() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func parsePacket(p []byte) (interface{}, error) {
|
||||
l := len(p)
|
||||
if l < 2 {
|
||||
return nil, fmt.Errorf("short packet")
|
||||
}
|
||||
opcode := binary.BigEndian.Uint16(p)
|
||||
switch opcode {
|
||||
case opRRQ:
|
||||
if l < 4 {
|
||||
return nil, fmt.Errorf("short RRQ packet: %d", l)
|
||||
}
|
||||
return pRRQ(p), nil
|
||||
case opWRQ:
|
||||
if l < 4 {
|
||||
return nil, fmt.Errorf("short WRQ packet: %d", l)
|
||||
}
|
||||
return pWRQ(p), nil
|
||||
case opDATA:
|
||||
if l < 4 {
|
||||
return nil, fmt.Errorf("short DATA packet: %d", l)
|
||||
}
|
||||
return pDATA(p), nil
|
||||
case opACK:
|
||||
if l < 4 {
|
||||
return nil, fmt.Errorf("short ACK packet: %d", l)
|
||||
}
|
||||
return pACK(p), nil
|
||||
case opERROR:
|
||||
if l < 5 {
|
||||
return nil, fmt.Errorf("short ERROR packet: %d", l)
|
||||
}
|
||||
return pERROR(p), nil
|
||||
case opOACK:
|
||||
if l < 6 {
|
||||
return nil, fmt.Errorf("short OACK packet: %d", l)
|
||||
}
|
||||
return pOACK(p), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown opcode: %d", opcode)
|
||||
}
|
||||
}
|
234
vendor/github.com/pin/tftp/receiver.go
generated
vendored
Normal file
234
vendor/github.com/pin/tftp/receiver.go
generated
vendored
Normal file
@ -0,0 +1,234 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pin/tftp/netascii"
|
||||
)
|
||||
|
||||
// IncomingTransfer provides methods that expose information associated with
|
||||
// an incoming transfer.
|
||||
type IncomingTransfer interface {
|
||||
// Size returns the size of an incoming file if the request included the
|
||||
// tsize option (see RFC2349). To differentiate a zero-sized file transfer
|
||||
// from a request without tsize use the second boolean "ok" return value.
|
||||
Size() (n int64, ok bool)
|
||||
|
||||
// RemoteAddr returns the remote peer's IP address and port.
|
||||
RemoteAddr() net.UDPAddr
|
||||
}
|
||||
|
||||
func (r *receiver) RemoteAddr() net.UDPAddr { return *r.addr }
|
||||
|
||||
func (r *receiver) Size() (n int64, ok bool) {
|
||||
if r.opts != nil {
|
||||
if s, ok := r.opts["tsize"]; ok {
|
||||
n, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
type receiver struct {
|
||||
send []byte
|
||||
receive []byte
|
||||
addr *net.UDPAddr
|
||||
tid int
|
||||
conn *net.UDPConn
|
||||
block uint16
|
||||
retry *backoff
|
||||
timeout time.Duration
|
||||
retries int
|
||||
l int
|
||||
autoTerm bool
|
||||
dally bool
|
||||
mode string
|
||||
opts options
|
||||
}
|
||||
|
||||
func (r *receiver) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if r.mode == "netascii" {
|
||||
w = netascii.FromWriter(w)
|
||||
}
|
||||
if r.opts != nil {
|
||||
err := r.sendOptions()
|
||||
if err != nil {
|
||||
r.abort(err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
binary.BigEndian.PutUint16(r.send[0:2], opACK)
|
||||
for {
|
||||
if r.l > 0 {
|
||||
l, err := w.Write(r.receive[4:r.l])
|
||||
n += int64(l)
|
||||
if err != nil {
|
||||
r.abort(err)
|
||||
return n, err
|
||||
}
|
||||
if r.l < len(r.receive) {
|
||||
if r.autoTerm {
|
||||
r.terminate()
|
||||
r.conn.Close()
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
binary.BigEndian.PutUint16(r.send[2:4], r.block)
|
||||
r.block++ // send ACK for current block and expect next one
|
||||
ll, _, err := r.receiveWithRetry(4)
|
||||
if err != nil {
|
||||
r.abort(err)
|
||||
return n, err
|
||||
}
|
||||
r.l = ll
|
||||
}
|
||||
}
|
||||
|
||||
func (r *receiver) sendOptions() error {
|
||||
for name, value := range r.opts {
|
||||
if name == "blksize" {
|
||||
err := r.setBlockSize(value)
|
||||
if err != nil {
|
||||
delete(r.opts, name)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
delete(r.opts, name)
|
||||
}
|
||||
}
|
||||
if len(r.opts) > 0 {
|
||||
m := packOACK(r.send, r.opts)
|
||||
r.block = 1 // expect data block number 1
|
||||
ll, _, err := r.receiveWithRetry(m)
|
||||
if err != nil {
|
||||
r.abort(err)
|
||||
return err
|
||||
}
|
||||
r.l = ll
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *receiver) setBlockSize(blksize string) error {
|
||||
n, err := strconv.Atoi(blksize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < 512 {
|
||||
return fmt.Errorf("blkzise too small: %d", n)
|
||||
}
|
||||
if n > 65464 {
|
||||
return fmt.Errorf("blksize too large: %d", n)
|
||||
}
|
||||
r.receive = make([]byte, n+4)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) {
|
||||
r.retry.reset()
|
||||
for {
|
||||
n, addr, err := r.receiveDatagram(l)
|
||||
if _, ok := err.(net.Error); ok && r.retry.count() < r.retries {
|
||||
r.retry.backoff()
|
||||
continue
|
||||
}
|
||||
return n, addr, err
|
||||
}
|
||||
}
|
||||
|
||||
func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) {
|
||||
err := r.conn.SetReadDeadline(time.Now().Add(r.timeout))
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
_, err = r.conn.WriteToUDP(r.send[:l], r.addr)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
for {
|
||||
c, addr, err := r.conn.ReadFromUDP(r.receive)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) {
|
||||
continue
|
||||
}
|
||||
p, err := parsePacket(r.receive[:c])
|
||||
if err != nil {
|
||||
return 0, addr, err
|
||||
}
|
||||
r.tid = addr.Port
|
||||
switch p := p.(type) {
|
||||
case pDATA:
|
||||
if p.block() == r.block {
|
||||
return c, addr, nil
|
||||
}
|
||||
case pOACK:
|
||||
opts, err := unpackOACK(p)
|
||||
if r.block != 1 {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
r.abort(err)
|
||||
return 0, addr, err
|
||||
}
|
||||
for name, value := range opts {
|
||||
if name == "blksize" {
|
||||
err := r.setBlockSize(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
r.block = 0 // ACK with block number 0
|
||||
r.opts = opts
|
||||
return 0, addr, nil
|
||||
case pERROR:
|
||||
return 0, addr, fmt.Errorf("code: %d, message: %s",
|
||||
p.code(), p.message())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *receiver) terminate() error {
|
||||
binary.BigEndian.PutUint16(r.send[2:4], r.block)
|
||||
if r.dally {
|
||||
for i := 0; i < 3; i++ {
|
||||
_, _, err := r.receiveDatagram(4)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("dallying termination failed")
|
||||
} else {
|
||||
_, err := r.conn.WriteToUDP(r.send[:4], r.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *receiver) abort(err error) error {
|
||||
if r.conn == nil {
|
||||
return nil
|
||||
}
|
||||
n := packERROR(r.send, 1, err.Error())
|
||||
_, err = r.conn.WriteToUDP(r.send[:n], r.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.conn.Close()
|
||||
r.conn = nil
|
||||
return nil
|
||||
}
|
243
vendor/github.com/pin/tftp/sender.go
generated
vendored
Normal file
243
vendor/github.com/pin/tftp/sender.go
generated
vendored
Normal file
@ -0,0 +1,243 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pin/tftp/netascii"
|
||||
)
|
||||
|
||||
// OutgoingTransfer provides methods to set the outgoing transfer size and
|
||||
// retrieve the remote address of the peer.
|
||||
type OutgoingTransfer interface {
|
||||
// SetSize is used to set the outgoing transfer size (tsize option: RFC2349)
|
||||
// manually in a server write transfer handler.
|
||||
//
|
||||
// It is not necessary in most cases; when the io.Reader provided to
|
||||
// ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will
|
||||
// be determined automatically. Seek will not be attempted when the
|
||||
// transfer size option is set with SetSize.
|
||||
//
|
||||
// The value provided will be used only if SetSize is called before ReadFrom
|
||||
// and only on in a server read handler.
|
||||
SetSize(n int64)
|
||||
|
||||
// RemoteAddr returns the remote peer's IP address and port.
|
||||
RemoteAddr() net.UDPAddr
|
||||
}
|
||||
|
||||
type sender struct {
|
||||
conn *net.UDPConn
|
||||
addr *net.UDPAddr
|
||||
tid int
|
||||
send []byte
|
||||
receive []byte
|
||||
retry *backoff
|
||||
timeout time.Duration
|
||||
retries int
|
||||
block uint16
|
||||
mode string
|
||||
opts options
|
||||
}
|
||||
|
||||
func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr }
|
||||
|
||||
func (s *sender) SetSize(n int64) {
|
||||
if s.opts != nil {
|
||||
if _, ok := s.opts["tsize"]; ok {
|
||||
s.opts["tsize"] = strconv.FormatInt(n, 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sender) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if s.mode == "netascii" {
|
||||
r = netascii.ToReader(r)
|
||||
}
|
||||
if s.opts != nil {
|
||||
// check that tsize is set
|
||||
if ts, ok := s.opts["tsize"]; ok {
|
||||
// check that tsize is not set with SetSize already
|
||||
i, err := strconv.ParseInt(ts, 10, 64)
|
||||
if err == nil && i == 0 {
|
||||
if rs, ok := r.(io.Seeker); ok {
|
||||
pos, err := rs.Seek(0, 1)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
size, err := rs.Seek(0, 2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.opts["tsize"] = strconv.FormatInt(size, 10)
|
||||
_, err = rs.Seek(pos, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
err = s.sendOptions()
|
||||
if err != nil {
|
||||
s.abort(err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
s.block = 1 // start data transmission with block 1
|
||||
binary.BigEndian.PutUint16(s.send[0:2], opDATA)
|
||||
for {
|
||||
l, err := io.ReadFull(r, s.send[4:])
|
||||
n += int64(l)
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
if err == io.EOF {
|
||||
binary.BigEndian.PutUint16(s.send[2:4], s.block)
|
||||
_, err = s.sendWithRetry(4)
|
||||
if err != nil {
|
||||
s.abort(err)
|
||||
return n, err
|
||||
}
|
||||
s.conn.Close()
|
||||
return n, nil
|
||||
}
|
||||
s.abort(err)
|
||||
return n, err
|
||||
}
|
||||
binary.BigEndian.PutUint16(s.send[2:4], s.block)
|
||||
_, err = s.sendWithRetry(4 + l)
|
||||
if err != nil {
|
||||
s.abort(err)
|
||||
return n, err
|
||||
}
|
||||
if l < len(s.send)-4 {
|
||||
s.conn.Close()
|
||||
return n, nil
|
||||
}
|
||||
s.block++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sender) sendOptions() error {
|
||||
for name, value := range s.opts {
|
||||
if name == "blksize" {
|
||||
err := s.setBlockSize(value)
|
||||
if err != nil {
|
||||
delete(s.opts, name)
|
||||
continue
|
||||
}
|
||||
} else if name == "tsize" {
|
||||
if value != "0" {
|
||||
s.opts["tsize"] = value
|
||||
} else {
|
||||
delete(s.opts, name)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
delete(s.opts, name)
|
||||
}
|
||||
}
|
||||
if len(s.opts) > 0 {
|
||||
m := packOACK(s.send, s.opts)
|
||||
_, err := s.sendWithRetry(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sender) setBlockSize(blksize string) error {
|
||||
n, err := strconv.Atoi(blksize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < 512 {
|
||||
return fmt.Errorf("blkzise too small: %d", n)
|
||||
}
|
||||
if n > 65464 {
|
||||
return fmt.Errorf("blksize too large: %d", n)
|
||||
}
|
||||
s.send = make([]byte, n+4)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) {
|
||||
s.retry.reset()
|
||||
for {
|
||||
addr, err := s.sendDatagram(l)
|
||||
if _, ok := err.(net.Error); ok && s.retry.count() < s.retries {
|
||||
s.retry.backoff()
|
||||
continue
|
||||
}
|
||||
return addr, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) {
|
||||
err := s.conn.SetReadDeadline(time.Now().Add(s.timeout))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = s.conn.WriteToUDP(s.send[:l], s.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
n, addr, err := s.conn.ReadFromUDP(s.receive)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) {
|
||||
continue
|
||||
}
|
||||
p, err := parsePacket(s.receive[:n])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
s.tid = addr.Port
|
||||
switch p := p.(type) {
|
||||
case pACK:
|
||||
if p.block() == s.block {
|
||||
return addr, nil
|
||||
}
|
||||
case pOACK:
|
||||
opts, err := unpackOACK(p)
|
||||
if s.block != 0 {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
s.abort(err)
|
||||
return addr, err
|
||||
}
|
||||
for name, value := range opts {
|
||||
if name == "blksize" {
|
||||
err := s.setBlockSize(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
return addr, nil
|
||||
case pERROR:
|
||||
return nil, fmt.Errorf("sending block %d: code=%d, error: %s",
|
||||
s.block, p.code(), p.message())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sender) abort(err error) error {
|
||||
if s.conn == nil {
|
||||
return nil
|
||||
}
|
||||
n := packERROR(s.send, 1, err.Error())
|
||||
_, err = s.conn.WriteToUDP(s.send[:n], s.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
return nil
|
||||
}
|
199
vendor/github.com/pin/tftp/server.go
generated
vendored
Normal file
199
vendor/github.com/pin/tftp/server.go
generated
vendored
Normal file
@ -0,0 +1,199 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewServer creates TFTP server. It requires two functions to handle
|
||||
// read and write requests.
|
||||
// In case nil is provided for read or write handler the respective
|
||||
// operation is disabled.
|
||||
func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
|
||||
writeHandler func(filename string, wt io.WriterTo) error) *Server {
|
||||
return &Server{
|
||||
readHandler: readHandler,
|
||||
writeHandler: writeHandler,
|
||||
timeout: defaultTimeout,
|
||||
retries: defaultRetries,
|
||||
}
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
readHandler func(filename string, rf io.ReaderFrom) error
|
||||
writeHandler func(filename string, wt io.WriterTo) error
|
||||
backoff backoffFunc
|
||||
conn *net.UDPConn
|
||||
quit chan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
timeout time.Duration
|
||||
retries int
|
||||
}
|
||||
|
||||
// SetTimeout sets maximum time server waits for single network
|
||||
// round-trip to succeed.
|
||||
// Default is 5 seconds.
|
||||
func (s *Server) SetTimeout(t time.Duration) {
|
||||
if t <= 0 {
|
||||
s.timeout = defaultTimeout
|
||||
} else {
|
||||
s.timeout = t
|
||||
}
|
||||
}
|
||||
|
||||
// SetRetries sets maximum number of attempts server made to transmit a
|
||||
// packet.
|
||||
// Default is 5 attempts.
|
||||
func (s *Server) SetRetries(count int) {
|
||||
if count < 1 {
|
||||
s.retries = defaultRetries
|
||||
} else {
|
||||
s.retries = count
|
||||
}
|
||||
}
|
||||
|
||||
// SetBackoff sets a user provided function that is called to provide a
|
||||
// backoff duration prior to retransmitting an unacknowledged packet.
|
||||
func (s *Server) SetBackoff(h backoffFunc) {
|
||||
s.backoff = h
|
||||
}
|
||||
|
||||
// ListenAndServe binds to address provided and start the server.
|
||||
// ListenAndServe returns when Shutdown is called.
|
||||
func (s *Server) ListenAndServe(addr string) error {
|
||||
a, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Serve(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve starts server provided already opened UDP connecton. It is
|
||||
// useful for the case when you want to run server in separate goroutine
|
||||
// but still want to be able to handle any errors opening connection.
|
||||
// Serve returns when Shutdown is called or connection is closed.
|
||||
func (s *Server) Serve(conn *net.UDPConn) {
|
||||
s.conn = conn
|
||||
s.quit = make(chan chan struct{})
|
||||
for {
|
||||
select {
|
||||
case q := <-s.quit:
|
||||
q <- struct{}{}
|
||||
return
|
||||
default:
|
||||
err := s.processRequest(s.conn)
|
||||
if err != nil {
|
||||
// TODO: add logging handler
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown make server stop listening for new requests, allows
|
||||
// server to finish outstanding transfers and stops server.
|
||||
func (s *Server) Shutdown() {
|
||||
s.conn.Close()
|
||||
q := make(chan struct{})
|
||||
s.quit <- q
|
||||
<-q
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) processRequest(conn *net.UDPConn) error {
|
||||
var buffer []byte
|
||||
buffer = make([]byte, datagramLength)
|
||||
n, remoteAddr, err := conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading UDP: %v", err)
|
||||
}
|
||||
p, err := parsePacket(buffer[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch p := p.(type) {
|
||||
case pWRQ:
|
||||
filename, mode, opts, err := unpackRQ(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpack WRQ: %v", err)
|
||||
}
|
||||
//fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("open transmission: %v", err)
|
||||
}
|
||||
wt := &receiver{
|
||||
send: make([]byte, datagramLength),
|
||||
receive: make([]byte, datagramLength),
|
||||
conn: conn,
|
||||
retry: &backoff{handler: s.backoff},
|
||||
timeout: s.timeout,
|
||||
retries: s.retries,
|
||||
addr: remoteAddr,
|
||||
mode: mode,
|
||||
opts: opts,
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
if s.writeHandler != nil {
|
||||
err := s.writeHandler(filename, wt)
|
||||
if err != nil {
|
||||
wt.abort(err)
|
||||
} else {
|
||||
wt.terminate()
|
||||
wt.conn.Close()
|
||||
}
|
||||
} else {
|
||||
wt.abort(fmt.Errorf("server does not support write requests"))
|
||||
}
|
||||
s.wg.Done()
|
||||
}()
|
||||
case pRRQ:
|
||||
filename, mode, opts, err := unpackRQ(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpack RRQ: %v", err)
|
||||
}
|
||||
//fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rf := &sender{
|
||||
send: make([]byte, datagramLength),
|
||||
receive: make([]byte, datagramLength),
|
||||
tid: remoteAddr.Port,
|
||||
conn: conn,
|
||||
retry: &backoff{handler: s.backoff},
|
||||
timeout: s.timeout,
|
||||
retries: s.retries,
|
||||
addr: remoteAddr,
|
||||
mode: mode,
|
||||
opts: opts,
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
if s.readHandler != nil {
|
||||
err := s.readHandler(filename, rf)
|
||||
if err != nil {
|
||||
rf.abort(err)
|
||||
}
|
||||
} else {
|
||||
rf.abort(fmt.Errorf("server does not support read requests"))
|
||||
}
|
||||
s.wg.Done()
|
||||
}()
|
||||
default:
|
||||
return fmt.Errorf("unexpected %T", p)
|
||||
}
|
||||
return nil
|
||||
}
|
622
vendor/github.com/pin/tftp/tftp_test.go
generated
vendored
Normal file
622
vendor/github.com/pin/tftp/tftp_test.go
generated
vendored
Normal file
@ -0,0 +1,622 @@
|
||||
package tftp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
"time"
|
||||
)
|
||||
|
||||
var localhost string = determineLocalhost()
|
||||
|
||||
func determineLocalhost() string {
|
||||
l, err := net.ListenTCP("tcp", nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ListenTCP error: %s", err))
|
||||
}
|
||||
_, lport, _ := net.SplitHostPort(l.Addr().String())
|
||||
defer l.Close()
|
||||
|
||||
lo := make(chan string)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
port, _ := strconv.Atoi(lport)
|
||||
for _, af := range []string{"tcp6", "tcp4"} {
|
||||
conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port})
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
host, _, _ := net.SplitHostPort(conn.LocalAddr().String())
|
||||
lo <- host
|
||||
return
|
||||
}
|
||||
}
|
||||
panic("could not determine address family")
|
||||
}()
|
||||
|
||||
return <-lo
|
||||
}
|
||||
|
||||
func localSystem(c *net.UDPConn) string {
|
||||
_, port, _ := net.SplitHostPort(c.LocalAddr().String())
|
||||
return net.JoinHostPort(localhost, port)
|
||||
}
|
||||
|
||||
func TestPackUnpack(t *testing.T) {
|
||||
v := []string{"test-filename/with-subdir"}
|
||||
testOptsList := []options{
|
||||
nil,
|
||||
options{
|
||||
"tsize": "1234",
|
||||
"blksize": "22",
|
||||
},
|
||||
}
|
||||
for _, filename := range v {
|
||||
for _, mode := range []string{"octet", "netascii"} {
|
||||
for _, opts := range testOptsList {
|
||||
packUnpack(t, filename, mode, opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func packUnpack(t *testing.T, filename, mode string, opts options) {
|
||||
b := make([]byte, datagramLength)
|
||||
for _, op := range []uint16{opRRQ, opWRQ} {
|
||||
n := packRQ(b, op, filename, mode, opts)
|
||||
f, m, o, err := unpackRQ(b[:n])
|
||||
if err != nil {
|
||||
t.Errorf("%s pack/unpack: %v", filename, err)
|
||||
}
|
||||
if f != filename {
|
||||
t.Errorf("filename mismatch (%s): '%x' vs '%x'",
|
||||
filename, f, filename)
|
||||
}
|
||||
if m != mode {
|
||||
t.Errorf("mode mismatch (%s): '%x' vs '%x'",
|
||||
mode, m, mode)
|
||||
}
|
||||
if opts != nil {
|
||||
for name, value := range opts {
|
||||
v, ok := o[name]
|
||||
if !ok {
|
||||
t.Errorf("missing %s option", name)
|
||||
}
|
||||
if v != value {
|
||||
t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroLength(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
testSendReceive(t, c, 0)
|
||||
}
|
||||
|
||||
func Test900(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
for i := 600; i < 4000; i += 1 {
|
||||
c.blksize = i
|
||||
testSendReceive(t, c, 9000+int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func Test1000(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
for i := int64(0); i < 5000; i++ {
|
||||
filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano())
|
||||
rf, err := c.Send(filename, "octet")
|
||||
if err != nil {
|
||||
t.Fatalf("requesting %s write: %v", filename, err)
|
||||
}
|
||||
r := io.LimitReader(newRandReader(rand.NewSource(i)), i)
|
||||
n, err := rf.ReadFrom(r)
|
||||
if err != nil {
|
||||
t.Fatalf("sending %s: %v", filename, err)
|
||||
}
|
||||
if n != i {
|
||||
t.Errorf("%s length mismatch: %d != %d", filename, n, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test1810(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
c.blksize = 1810
|
||||
testSendReceive(t, c, 9000+1810)
|
||||
}
|
||||
|
||||
func TestTSize(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
c.tsize = true
|
||||
testSendReceive(t, c, 640)
|
||||
}
|
||||
|
||||
func TestNearBlockLength(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
for i := 450; i < 520; i++ {
|
||||
testSendReceive(t, c, int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlockWrapsAround(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
n := 65535 * 512
|
||||
for i := n - 2; i < n+2; i++ {
|
||||
testSendReceive(t, c, int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomLength(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
r := rand.New(rand.NewSource(42))
|
||||
for i := 0; i < 100; i++ {
|
||||
testSendReceive(t, c, r.Int63n(100000))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBigFile(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
testSendReceive(t, c, 3*1000*1000)
|
||||
}
|
||||
|
||||
func TestByOneByte(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
filename := "test-by-one-byte"
|
||||
mode := "octet"
|
||||
const length = 80000
|
||||
sender, err := c.Send(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting write: %v", err)
|
||||
}
|
||||
r := iotest.OneByteReader(io.LimitReader(
|
||||
newRandReader(rand.NewSource(42)), length))
|
||||
n, err := sender.ReadFrom(r)
|
||||
if err != nil {
|
||||
t.Fatalf("send error: %v", err)
|
||||
}
|
||||
if n != length {
|
||||
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
|
||||
}
|
||||
readTransfer, err := c.Receive(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting read %s: %v", filename, err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
n, err = readTransfer.WriteTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s read error: %v", filename, err)
|
||||
}
|
||||
if n != length {
|
||||
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
|
||||
}
|
||||
bs, _ := ioutil.ReadAll(io.LimitReader(
|
||||
newRandReader(rand.NewSource(42)), length))
|
||||
if !bytes.Equal(bs, buf.Bytes()) {
|
||||
t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicate(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
filename := "test-duplicate"
|
||||
mode := "octet"
|
||||
bs := []byte("lalala")
|
||||
sender, err := c.Send(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting write: %v", err)
|
||||
}
|
||||
buf := bytes.NewBuffer(bs)
|
||||
_, err = sender.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("send error: %v", err)
|
||||
}
|
||||
sender, err = c.Send(filename, mode)
|
||||
if err == nil {
|
||||
t.Fatalf("file already exists")
|
||||
}
|
||||
t.Logf("sending file that already exists: %v", err)
|
||||
}
|
||||
|
||||
func TestNotFound(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
defer s.Shutdown()
|
||||
filename := "test-not-exists"
|
||||
mode := "octet"
|
||||
_, err := c.Receive(filename, mode)
|
||||
if err == nil {
|
||||
t.Fatalf("file not exists", err)
|
||||
}
|
||||
t.Logf("receiving file that does not exist: %v", err)
|
||||
}
|
||||
|
||||
func testSendReceive(t *testing.T, client *Client, length int64) {
|
||||
filename := fmt.Sprintf("length-%d-bytes", length)
|
||||
mode := "octet"
|
||||
writeTransfer, err := client.Send(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting write %s: %v", filename, err)
|
||||
}
|
||||
r := io.LimitReader(newRandReader(rand.NewSource(42)), length)
|
||||
n, err := writeTransfer.ReadFrom(r)
|
||||
if err != nil {
|
||||
t.Fatalf("%s write error: %v", filename, err)
|
||||
}
|
||||
if n != length {
|
||||
t.Errorf("%s write length mismatch: %d != %d", filename, n, length)
|
||||
}
|
||||
readTransfer, err := client.Receive(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting read %s: %v", filename, err)
|
||||
}
|
||||
if it, ok := readTransfer.(IncomingTransfer); ok {
|
||||
if n, ok := it.Size(); ok {
|
||||
fmt.Printf("Transfer size: %d\n", n)
|
||||
if n != length {
|
||||
t.Errorf("tsize mismatch: %d vs %d", n, length)
|
||||
}
|
||||
}
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
n, err = readTransfer.WriteTo(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("%s read error: %v", filename, err)
|
||||
}
|
||||
if n != length {
|
||||
t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
|
||||
}
|
||||
bs, _ := ioutil.ReadAll(io.LimitReader(
|
||||
newRandReader(rand.NewSource(42)), length))
|
||||
if !bytes.Equal(bs, buf.Bytes()) {
|
||||
t.Errorf("\nsent: %x\nrcvd: %x", bs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendTsizeFromSeek(t *testing.T) {
|
||||
// create read-only server
|
||||
s := NewServer(func(filename string, rf io.ReaderFrom) error {
|
||||
b := make([]byte, 100)
|
||||
rr := newRandReader(rand.NewSource(42))
|
||||
rr.Read(b)
|
||||
// bytes.Reader implements io.Seek
|
||||
r := bytes.NewReader(b)
|
||||
_, err := rf.ReadFrom(r)
|
||||
if err != nil {
|
||||
t.Errorf("sending bytes: %v", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
t.Fatalf("listening: %v", err)
|
||||
}
|
||||
|
||||
go s.Serve(conn)
|
||||
defer s.Shutdown()
|
||||
|
||||
c, _ := NewClient(localSystem(conn))
|
||||
c.tsize = true
|
||||
r, _ := c.Receive("f", "octet")
|
||||
var size int64
|
||||
if t, ok := r.(IncomingTransfer); ok {
|
||||
if n, ok := t.Size(); ok {
|
||||
size = n
|
||||
fmt.Printf("Transfer size: %d\n", n)
|
||||
}
|
||||
}
|
||||
|
||||
if size != 100 {
|
||||
t.Errorf("size expected: 100, got %d", size)
|
||||
}
|
||||
|
||||
r.WriteTo(ioutil.Discard)
|
||||
}
|
||||
|
||||
type testBackend struct {
|
||||
m map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func makeTestServer() (*Server, *Client) {
|
||||
b := &testBackend{}
|
||||
b.m = make(map[string][]byte)
|
||||
|
||||
// Create server
|
||||
s := NewServer(b.handleRead, b.handleWrite)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go s.Serve(conn)
|
||||
|
||||
// Create client for that server
|
||||
c, err := NewClient(localSystem(conn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return s, c
|
||||
}
|
||||
|
||||
func TestNoHandlers(t *testing.T) {
|
||||
s := NewServer(nil, nil)
|
||||
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go s.Serve(conn)
|
||||
|
||||
c, err := NewClient(localSystem(conn))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = c.Send("test", "octet")
|
||||
if err == nil {
|
||||
t.Errorf("error expected")
|
||||
}
|
||||
|
||||
_, err = c.Receive("test", "octet")
|
||||
if err == nil {
|
||||
t.Errorf("error expected")
|
||||
}
|
||||
}
|
||||
|
||||
func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
_, ok := b.m[filename]
|
||||
if ok {
|
||||
fmt.Fprintf(os.Stderr, "File %s already exists\n", filename)
|
||||
return fmt.Errorf("file already exists")
|
||||
}
|
||||
if t, ok := wt.(IncomingTransfer); ok {
|
||||
if n, ok := t.Size(); ok {
|
||||
fmt.Printf("Transfer size: %d\n", n)
|
||||
}
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
_, err := wt.WriteTo(buf)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err)
|
||||
return err
|
||||
}
|
||||
b.m[filename] = buf.Bytes()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
bs, ok := b.m[filename]
|
||||
if !ok {
|
||||
fmt.Fprintf(os.Stderr, "File %s not found\n", filename)
|
||||
return fmt.Errorf("file not found")
|
||||
}
|
||||
if t, ok := rf.(OutgoingTransfer); ok {
|
||||
t.SetSize(int64(len(bs)))
|
||||
}
|
||||
_, err := rf.ReadFrom(bytes.NewBuffer(bs))
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type randReader struct {
|
||||
src rand.Source
|
||||
next int64
|
||||
i int8
|
||||
}
|
||||
|
||||
func newRandReader(src rand.Source) io.Reader {
|
||||
r := &randReader{
|
||||
src: src,
|
||||
next: src.Int63(),
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *randReader) Read(p []byte) (n int, err error) {
|
||||
next, i := r.next, r.i
|
||||
for n = 0; n < len(p); n++ {
|
||||
if i == 7 {
|
||||
next, i = r.src.Int63(), 0
|
||||
}
|
||||
p[n] = byte(next)
|
||||
next >>= 8
|
||||
i++
|
||||
}
|
||||
r.next, r.i = next, i
|
||||
return
|
||||
}
|
||||
|
||||
func TestServerSendTimeout(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
s.SetTimeout(time.Second)
|
||||
s.SetRetries(2)
|
||||
var serverErr error
|
||||
s.readHandler = func(filename string, rf io.ReaderFrom) error {
|
||||
r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
|
||||
_, serverErr = rf.ReadFrom(r)
|
||||
return serverErr
|
||||
}
|
||||
defer s.Shutdown()
|
||||
filename := "test-server-send-timeout"
|
||||
mode := "octet"
|
||||
readTransfer, err := c.Receive(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting read %s: %v", filename, err)
|
||||
}
|
||||
w := &slowWriter{
|
||||
n: 3,
|
||||
delay: 8 * time.Second,
|
||||
}
|
||||
_, _ = readTransfer.WriteTo(w)
|
||||
netErr, ok := serverErr.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("network error expected: %T", serverErr)
|
||||
}
|
||||
if !netErr.Timeout() {
|
||||
t.Fatalf("timout is expected: %v", serverErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerReceiveTimeout(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
s.SetTimeout(time.Second)
|
||||
s.SetRetries(2)
|
||||
var serverErr error
|
||||
s.writeHandler = func(filename string, wt io.WriterTo) error {
|
||||
buf := &bytes.Buffer{}
|
||||
_, serverErr = wt.WriteTo(buf)
|
||||
return serverErr
|
||||
}
|
||||
defer s.Shutdown()
|
||||
filename := "test-server-receive-timeout"
|
||||
mode := "octet"
|
||||
writeTransfer, err := c.Send(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting write %s: %v", filename, err)
|
||||
}
|
||||
r := &slowReader{
|
||||
r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
|
||||
n: 3,
|
||||
delay: 8 * time.Second,
|
||||
}
|
||||
_, _ = writeTransfer.ReadFrom(r)
|
||||
netErr, ok := serverErr.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("network error expected: %T", serverErr)
|
||||
}
|
||||
if !netErr.Timeout() {
|
||||
t.Fatalf("timout is expected: %v", serverErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReceiveTimeout(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
c.SetTimeout(time.Second)
|
||||
c.SetRetries(2)
|
||||
s.readHandler = func(filename string, rf io.ReaderFrom) error {
|
||||
r := &slowReader{
|
||||
r: io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
|
||||
n: 3,
|
||||
delay: 8 * time.Second,
|
||||
}
|
||||
_, err := rf.ReadFrom(r)
|
||||
return err
|
||||
}
|
||||
defer s.Shutdown()
|
||||
filename := "test-client-receive-timeout"
|
||||
mode := "octet"
|
||||
readTransfer, err := c.Receive(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting read %s: %v", filename, err)
|
||||
}
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = readTransfer.WriteTo(buf)
|
||||
netErr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("network error expected: %T", err)
|
||||
}
|
||||
if !netErr.Timeout() {
|
||||
t.Fatalf("timout is expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientSendTimeout(t *testing.T) {
|
||||
s, c := makeTestServer()
|
||||
c.SetTimeout(time.Second)
|
||||
c.SetRetries(2)
|
||||
s.writeHandler = func(filename string, wt io.WriterTo) error {
|
||||
w := &slowWriter{
|
||||
n: 3,
|
||||
delay: 8 * time.Second,
|
||||
}
|
||||
_, err := wt.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
defer s.Shutdown()
|
||||
filename := "test-client-send-timeout"
|
||||
mode := "octet"
|
||||
writeTransfer, err := c.Send(filename, mode)
|
||||
if err != nil {
|
||||
t.Fatalf("requesting write %s: %v", filename, err)
|
||||
}
|
||||
r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
|
||||
_, err = writeTransfer.ReadFrom(r)
|
||||
netErr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("network error expected: %T", err)
|
||||
}
|
||||
if !netErr.Timeout() {
|
||||
t.Fatalf("timout is expected: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type slowReader struct {
|
||||
r io.Reader
|
||||
n int64
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (r *slowReader) Read(p []byte) (n int, err error) {
|
||||
if r.n > 0 {
|
||||
r.n--
|
||||
return r.r.Read(p)
|
||||
}
|
||||
time.Sleep(r.delay)
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
type slowWriter struct {
|
||||
r io.Reader
|
||||
n int64
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (r *slowWriter) Write(p []byte) (n int, err error) {
|
||||
if r.n > 0 {
|
||||
r.n--
|
||||
return len(p), nil
|
||||
}
|
||||
time.Sleep(r.delay)
|
||||
return len(p), nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user