mirror of
https://github.com/rancher/os.git
synced 2025-07-08 04:18:38 +00:00
200 lines
4.7 KiB
Go
200 lines
4.7 KiB
Go
package tftp
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// NewServer creates TFTP server. It requires two functions to handle
|
|
// read and write requests.
|
|
// In case nil is provided for read or write handler the respective
|
|
// operation is disabled.
|
|
func NewServer(readHandler func(filename string, rf io.ReaderFrom) error,
|
|
writeHandler func(filename string, wt io.WriterTo) error) *Server {
|
|
return &Server{
|
|
readHandler: readHandler,
|
|
writeHandler: writeHandler,
|
|
timeout: defaultTimeout,
|
|
retries: defaultRetries,
|
|
}
|
|
}
|
|
|
|
type Server struct {
|
|
readHandler func(filename string, rf io.ReaderFrom) error
|
|
writeHandler func(filename string, wt io.WriterTo) error
|
|
backoff backoffFunc
|
|
conn *net.UDPConn
|
|
quit chan chan struct{}
|
|
wg sync.WaitGroup
|
|
timeout time.Duration
|
|
retries int
|
|
}
|
|
|
|
// SetTimeout sets maximum time server waits for single network
|
|
// round-trip to succeed.
|
|
// Default is 5 seconds.
|
|
func (s *Server) SetTimeout(t time.Duration) {
|
|
if t <= 0 {
|
|
s.timeout = defaultTimeout
|
|
} else {
|
|
s.timeout = t
|
|
}
|
|
}
|
|
|
|
// SetRetries sets maximum number of attempts server made to transmit a
|
|
// packet.
|
|
// Default is 5 attempts.
|
|
func (s *Server) SetRetries(count int) {
|
|
if count < 1 {
|
|
s.retries = defaultRetries
|
|
} else {
|
|
s.retries = count
|
|
}
|
|
}
|
|
|
|
// SetBackoff sets a user provided function that is called to provide a
|
|
// backoff duration prior to retransmitting an unacknowledged packet.
|
|
func (s *Server) SetBackoff(h backoffFunc) {
|
|
s.backoff = h
|
|
}
|
|
|
|
// ListenAndServe binds to address provided and start the server.
|
|
// ListenAndServe returns when Shutdown is called.
|
|
func (s *Server) ListenAndServe(addr string) error {
|
|
a, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn, err := net.ListenUDP("udp", a)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.Serve(conn)
|
|
return nil
|
|
}
|
|
|
|
// Serve starts server provided already opened UDP connecton. It is
|
|
// useful for the case when you want to run server in separate goroutine
|
|
// but still want to be able to handle any errors opening connection.
|
|
// Serve returns when Shutdown is called or connection is closed.
|
|
func (s *Server) Serve(conn *net.UDPConn) {
|
|
s.conn = conn
|
|
s.quit = make(chan chan struct{})
|
|
for {
|
|
select {
|
|
case q := <-s.quit:
|
|
q <- struct{}{}
|
|
return
|
|
default:
|
|
err := s.processRequest(s.conn)
|
|
if err != nil {
|
|
// TODO: add logging handler
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Shutdown make server stop listening for new requests, allows
|
|
// server to finish outstanding transfers and stops server.
|
|
func (s *Server) Shutdown() {
|
|
s.conn.Close()
|
|
q := make(chan struct{})
|
|
s.quit <- q
|
|
<-q
|
|
s.wg.Wait()
|
|
}
|
|
|
|
func (s *Server) processRequest(conn *net.UDPConn) error {
|
|
var buffer []byte
|
|
buffer = make([]byte, datagramLength)
|
|
n, remoteAddr, err := conn.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
return fmt.Errorf("reading UDP: %v", err)
|
|
}
|
|
p, err := parsePacket(buffer[:n])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
switch p := p.(type) {
|
|
case pWRQ:
|
|
filename, mode, opts, err := unpackRQ(p)
|
|
if err != nil {
|
|
return fmt.Errorf("unpack WRQ: %v", err)
|
|
}
|
|
//fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
|
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("open transmission: %v", err)
|
|
}
|
|
wt := &receiver{
|
|
send: make([]byte, datagramLength),
|
|
receive: make([]byte, datagramLength),
|
|
conn: conn,
|
|
retry: &backoff{handler: s.backoff},
|
|
timeout: s.timeout,
|
|
retries: s.retries,
|
|
addr: remoteAddr,
|
|
mode: mode,
|
|
opts: opts,
|
|
}
|
|
s.wg.Add(1)
|
|
go func() {
|
|
if s.writeHandler != nil {
|
|
err := s.writeHandler(filename, wt)
|
|
if err != nil {
|
|
wt.abort(err)
|
|
} else {
|
|
wt.terminate()
|
|
wt.conn.Close()
|
|
}
|
|
} else {
|
|
wt.abort(fmt.Errorf("server does not support write requests"))
|
|
}
|
|
s.wg.Done()
|
|
}()
|
|
case pRRQ:
|
|
filename, mode, opts, err := unpackRQ(p)
|
|
if err != nil {
|
|
return fmt.Errorf("unpack RRQ: %v", err)
|
|
}
|
|
//fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts)
|
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rf := &sender{
|
|
send: make([]byte, datagramLength),
|
|
receive: make([]byte, datagramLength),
|
|
tid: remoteAddr.Port,
|
|
conn: conn,
|
|
retry: &backoff{handler: s.backoff},
|
|
timeout: s.timeout,
|
|
retries: s.retries,
|
|
addr: remoteAddr,
|
|
mode: mode,
|
|
opts: opts,
|
|
}
|
|
s.wg.Add(1)
|
|
go func() {
|
|
if s.readHandler != nil {
|
|
err := s.readHandler(filename, rf)
|
|
if err != nil {
|
|
rf.abort(err)
|
|
}
|
|
} else {
|
|
rf.abort(fmt.Errorf("server does not support read requests"))
|
|
}
|
|
s.wg.Done()
|
|
}()
|
|
default:
|
|
return fmt.Errorf("unexpected %T", p)
|
|
}
|
|
return nil
|
|
}
|