nc-vsock: Add support for Hyper-V sockets

Also tidy up some of the coding style to be more Linux kernel style
which most of the code already was.

Signed-off-by: Rolf Neugebauer <rolf.neugebauer@docker.com>
This commit is contained in:
Rolf Neugebauer 2016-04-18 17:17:32 +01:00
parent df350345cd
commit 59814bc752
3 changed files with 196 additions and 41 deletions

View File

@ -1,6 +1,6 @@
FROM alpine:3.3
RUN apk update && apk upgrade && apk add alpine-sdk
RUN apk update && apk upgrade && apk add alpine-sdk util-linux-dev
RUN mkdir -p /nc-vsock
WORKDIR /nc-vsock

View File

@ -8,7 +8,7 @@ all: Dockerfile $(DEPS)
chmod 755 nc-vsock
nc-vsock: $(DEPS)
gcc -Wall -Werror -o nc-vsock nc-vsock.c
gcc -Wall -Werror -o nc-vsock nc-vsock.c -luuid
clean:
rm -f nc-vsock

View File

@ -9,16 +9,56 @@
#include <sys/socket.h>
#include <sys/select.h>
#include <netdb.h>
#include <uuid/uuid.h>
#include "include/uapi/linux/vm_sockets.h"
#define MODE_READ 1 /* From the vsock */
#define MODE_WRITE 2 /* To the vsock */
#define MODE_RDWR (MODE_READ|MODE_WRITE)
/*
* Hyper-V Sockets headerfile pull in too much other stuff. Replicate
* the bits we need here.
*/
#ifndef AF_HYPERV
#define AF_HYPERV 42
#endif
struct sockaddr_hv {
unsigned short shv_family; /* Address family */
unsigned short reserved; /* Must be Zero */
uuid_t shv_vm_id; /* Not used. Must be Zero. */
uuid_t shv_service_id; /* Service ID */
};
UUID_DEFINE(SHV_VMID_GUEST,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0);
#define SHV_PROTO_RAW 1
/*
* MSFT's GUIDs are a bonkers mix of native and big endian byte
* order. The uuid library uses RFC 4122, which is always big endian.
* The Linux kernel uuid.h actually looks more like it should be
* called guid.h. We use the uuid library for ease of parsing/printing
* and then this function to convert between UUID and GUID.
* https://en.wikipedia.org/wiki/Globally_unique_identifier
*/
static void uuid2guid(uuid_t u)
{
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
char t;
t = u[0]; u[0] = u[3]; u[3] = t;
t = u[1]; u[1] = u[2]; u[2] = t;
t = u[4]; u[4] = u[5]; u[5] = t;
t = u[6]; u[6] = u[7]; u[7] = t;
#endif
}
static int parse_cid(const char *cid_str)
{
char *end = NULL;
long cid = strtol(cid_str, &end, 10);
if (cid_str != end && *end == '\0') {
return cid;
} else {
@ -31,6 +71,7 @@ static int parse_port(const char *port_str)
{
char *end = NULL;
long port = strtol(port_str, &end, 10);
if (port_str != end && *end == '\0') {
return port;
} else {
@ -49,10 +90,12 @@ static int vsock_listen(const char *port_str)
};
struct sockaddr_vm sa_client;
socklen_t socklen_client = sizeof(sa_client);
int port = parse_port(port_str);
if (port < 0) {
int port;
int ret;
port = parse_port(port_str);
if (port < 0)
return -1;
}
sa_listen.svm_port = port;
@ -62,26 +105,87 @@ static int vsock_listen(const char *port_str)
return -1;
}
if (bind(listen_fd, (struct sockaddr*)&sa_listen, sizeof(sa_listen)) != 0) {
ret = bind(listen_fd, (struct sockaddr*)&sa_listen, sizeof(sa_listen));
if (ret != 0) {
perror("bind");
close(listen_fd);
return -1;
}
if (listen(listen_fd, 1) != 0) {
ret = listen(listen_fd, 1);
if (ret != 0) {
perror("listen");
close(listen_fd);
return -1;
}
client_fd = accept(listen_fd, (struct sockaddr*)&sa_client, &socklen_client);
client_fd = accept(listen_fd,
(struct sockaddr*)&sa_client, &socklen_client);
if (client_fd < 0) {
perror("accept");
close(listen_fd);
return -1;
}
fprintf(stderr, "Connection from cid %u port %u...\n", sa_client.svm_cid, sa_client.svm_port);
fprintf(stderr, "Connection from cid %u port %u...\n",
sa_client.svm_cid, sa_client.svm_port);
close(listen_fd);
return client_fd;
}
static int hvsock_listen(const char *port_str)
{
int listen_fd;
int client_fd;
struct sockaddr_hv sa_listen = {
.shv_family = AF_HYPERV,
.reserved = 0,
};
struct sockaddr_hv sa_client;
socklen_t socklen_client = sizeof(sa_client);
char vm_str[128], svc_str[128];
int ret;
uuid_copy(sa_listen.shv_vm_id, SHV_VMID_GUEST);
ret = uuid_parse(port_str, sa_listen.shv_service_id);
if (ret != 0)
return -1;
uuid2guid(sa_listen.shv_service_id);
listen_fd = socket(AF_HYPERV, SOCK_STREAM, SHV_PROTO_RAW);
if (listen_fd < 0) {
perror("socket");
return -1;
}
ret = bind(listen_fd, (struct sockaddr*)&sa_listen, sizeof(sa_listen));
if (ret != 0) {
perror("bind");
close(listen_fd);
return -1;
}
ret = listen(listen_fd, 1);
if (ret != 0) {
perror("listen");
close(listen_fd);
return -1;
}
client_fd = accept(listen_fd,
(struct sockaddr*)&sa_client, &socklen_client);
if (client_fd < 0) {
perror("accept");
close(listen_fd);
return -1;
}
uuid_unparse(sa_client.shv_vm_id, vm_str);
uuid_unparse(sa_client.shv_service_id, svc_str);
fprintf(stderr, "Connection from %s port %s...\n", vm_str, svc_str);
close(listen_fd);
return client_fd;
@ -105,13 +209,15 @@ static int tcp_connect(const char *node, const char *service)
}
for (addrinfo = res; addrinfo; addrinfo = addrinfo->ai_next) {
fd = socket(addrinfo->ai_family, addrinfo->ai_socktype, addrinfo->ai_protocol);
fd = socket(addrinfo->ai_family,
addrinfo->ai_socktype, addrinfo->ai_protocol);
if (fd < 0) {
perror("socket");
continue;
}
if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0) {
ret = connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen);
if (ret != 0) {
perror("connect");
close(fd);
continue;
@ -132,17 +238,18 @@ static int vsock_connect(const char *cid_str, const char *port_str)
struct sockaddr_vm sa = {
.svm_family = AF_VSOCK,
};
int ret;
cid = parse_cid(cid_str);
if (cid < 0) {
if (cid < 0)
return -1;
}
sa.svm_cid = cid;
port = parse_port(port_str);
if (port < 0) {
if (port < 0)
return -1;
}
sa.svm_port = port;
fd = socket(AF_VSOCK, SOCK_STREAM, 0);
@ -151,7 +258,47 @@ static int vsock_connect(const char *cid_str, const char *port_str)
return -1;
}
if (connect(fd, (struct sockaddr*)&sa, sizeof(sa)) != 0) {
ret = connect(fd, (struct sockaddr*)&sa, sizeof(sa));
if (ret != 0) {
perror("connect");
close(fd);
return -1;
}
return fd;
}
static int hvsock_connect(const char *vm_str, const char *svc_str)
{
int fd;
int ret;
struct sockaddr_hv sa = {
.shv_family = AF_HYPERV,
.reserved = 0,
};
ret = uuid_parse(vm_str, sa.shv_vm_id);
if (ret != 0) {
fprintf(stderr, "VM GUID parse error: %s\n", vm_str);
return -1;
}
uuid2guid(sa.shv_vm_id);
ret = uuid_parse(svc_str, sa.shv_service_id);
if (ret != 0) {
fprintf(stderr, "Service GUID parse error: %s\n", svc_str);
return -1;
}
uuid2guid(sa.shv_service_id);
fd = socket(AF_HYPERV, SOCK_STREAM, SHV_PROTO_RAW);
if (fd < 0) {
perror("socket");
return -1;
}
ret = connect(fd, (struct sockaddr*)&sa, sizeof(sa));
if (ret != 0) {
perror("connect");
close(fd);
return -1;
@ -166,10 +313,12 @@ static int get_fds(int argc, char **argv, int fds[2])
fds[1] = -1;
if (argc >= 3 && strcmp(argv[1], "-l") == 0) {
fds[1] = vsock_listen(argv[2]);
if (fds[1] < 0) {
if (strstr(argv[2], "-"))
fds[1] = hvsock_listen(argv[2]);
else
fds[1] = vsock_listen(argv[2]);
if (fds[1] < 0)
return -1;
}
if (argc == 6 && strcmp(argv[3], "-t") == 0) {
fds[0] = tcp_connect(argv[4], argv[5]);
@ -179,10 +328,12 @@ static int get_fds(int argc, char **argv, int fds[2])
}
return 0;
} else if (argc == 3) {
fds[1] = vsock_connect(argv[1], argv[2]);
if (fds[1] < 0) {
if (strstr(argv[1], "-") || strstr(argv[2], "-"))
fds[1] = hvsock_connect(argv[1], argv[2]);
else
fds[1] = vsock_connect(argv[1], argv[2]);
if (fds[1] < 0)
return -1;
}
return 0;
} else {
fprintf(stderr, "usage: %s [-r|-w] [-l <port> [-t <dst> <dstport>] | <cid> <port>]\n", argv[0]);
@ -202,9 +353,8 @@ static void set_nonblock(int fd, bool enable)
}
flags = ret & ~O_NONBLOCK;
if (enable) {
if (enable)
flags |= O_NONBLOCK;
}
fcntl(fd, F_SETFL, flags);
}
@ -215,18 +365,20 @@ static int xfer_data(int in_fd, int out_fd)
char *send_ptr = buf;
ssize_t nbytes;
ssize_t remaining;
int ret;
if (out_fd == STDIN_FILENO) out_fd = STDOUT_FILENO;
nbytes = read(in_fd, buf, sizeof(buf));
if (nbytes < 0) {
if (nbytes < 0)
return -1;
}
if (nbytes == 0) {
int rc;
if (out_fd == STDOUT_FILENO) return 0;
rc = shutdown(out_fd, SHUT_WR);
if (rc == 0) return 0;
if (out_fd == STDOUT_FILENO)
return 0;
ret = shutdown(out_fd, SHUT_WR);
if (ret == 0)
return 0;
perror("shutdown");
return -1;
}
@ -234,11 +386,10 @@ static int xfer_data(int in_fd, int out_fd)
remaining = nbytes;
while (remaining > 0) {
nbytes = write(out_fd, send_ptr, remaining);
if (nbytes < 0 && errno == EAGAIN) {
if (nbytes < 0 && errno == EAGAIN)
nbytes = 0;
} else if (nbytes <= 0) {
else if (nbytes <= 0)
return -1;
}
if (remaining > nbytes) {
/* Wait for fd to become writeable again */
@ -246,7 +397,9 @@ static int xfer_data(int in_fd, int out_fd)
fd_set wfds;
FD_ZERO(&wfds);
FD_SET(out_fd, &wfds);
if (select(out_fd + 1, NULL, &wfds, NULL, NULL) < 0) {
ret = select(out_fd + 1, NULL,
&wfds, NULL, NULL);
if (ret < 0) {
if (errno == EINTR) {
continue;
} else {
@ -255,9 +408,8 @@ static int xfer_data(int in_fd, int out_fd)
}
}
if (FD_ISSET(out_fd, &wfds)) {
if (FD_ISSET(out_fd, &wfds))
break;
}
}
}
@ -273,6 +425,7 @@ static void main_loop(int fds[2], int mode)
int nfds = fds[fds[0] > fds[1] ? 0 : 1] + 1;
/* Which fd's are readable */
bool rfd0 = !!(mode&MODE_WRITE), rfd1 = !!(mode&MODE_READ);
int ret;
set_nonblock(fds[0], true);
set_nonblock(fds[1], true);
@ -282,10 +435,13 @@ static void main_loop(int fds[2], int mode)
return;
FD_ZERO(&rfds);
if (rfd0) FD_SET(fds[0], &rfds);
if (rfd1) FD_SET(fds[1], &rfds);
if (rfd0)
FD_SET(fds[0], &rfds);
if (rfd1)
FD_SET(fds[1], &rfds);
if (select(nfds, &rfds, NULL, NULL, NULL) < 0) {
ret = select(nfds, &rfds, NULL, NULL, NULL);
if (ret < 0) {
if (errno == EINTR) {
continue;
} else {
@ -327,9 +483,8 @@ int main(int argc, char **argv)
}
}
if (get_fds(argc, argv, fds) < 0) {
if (get_fds(argc, argv, fds) < 0)
return EXIT_FAILURE;
}
main_loop(fds, mode);
return EXIT_SUCCESS;