vsock: add multi-transports support
This patch adds the support of multiple transports in the
VSOCK core.
With the multi-transports support, we can use vsock with nested VMs
(using also different hypervisors) loading both guest->host and
host->guest transports at the same time.
Major changes:
- vsock core module can be loaded regardless of the transports
- vsock_core_init() and vsock_core_exit() are renamed to
vsock_core_register() and vsock_core_unregister()
- vsock_core_register() has a feature parameter (H2G, G2H, DGRAM)
to identify which directions the transport can handle and whether it
supports DGRAM (only vmci)
- each stream socket is assigned to a transport when the remote CID
is set (during the connect() or when we receive a connection request
on a listener socket).
The remote CID is used to decide which transport to use:
- remote CID <= VMADDR_CID_HOST will use guest->host transport;
- remote CID == local_cid (guest->host transport) will use guest->host
transport for loopback (host->guest transports don't support loopback);
- remote CID > VMADDR_CID_HOST will use host->guest transport;
- listener sockets are not bound to any transport, since no transport
operations are done on them. In this way we can create a listener
socket even if the transports are not loaded, or with VMADDR_CID_ANY
to listen on all transports.
- DGRAM sockets are handled as before, since only the vmci_transport
provides this feature.
Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 8985d9d..5357714 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -130,7 +130,12 @@ static struct proto vsock_proto = {
#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
-static const struct vsock_transport *transport_single;
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
static DEFINE_MUTEX(vsock_register_mutex);
/**** UTILS ****/
@@ -182,7 +187,7 @@ static int vsock_auto_bind(struct vsock_sock *vsk)
return __vsock_bind(sk, &local_addr);
}
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
{
int i;
@@ -191,7 +196,6 @@ static int __init vsock_init_tables(void)
for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
INIT_LIST_HEAD(&vsock_connected_table[i]);
- return 0;
}
static void __vsock_insert_bound(struct list_head *list,
@@ -376,6 +380,68 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected)
}
EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream socket this must be called when vsk->remote_addr is set
+ * (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ *  - remote CID <= VMADDR_CID_HOST will use guest->host transport;
+ *  - remote CID == local_cid (guest->host transport) will use guest->host
+ *    transport for loopback (host->guest transports don't support loopback);
+ *  - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ * DGRAM sockets always use the registered DGRAM transport.
+ *
+ * If the socket already has a different transport assigned, that transport
+ * is released and destructed before the new one is initialized.
+ *
+ * Returns 0 if the selected transport is already assigned, the result of
+ * the transport .init callback when a new transport is assigned,
+ * -ESOCKTNOSUPPORT for unsupported socket types, or -ENODEV if no suitable
+ * transport is registered.
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+	const struct vsock_transport *new_transport;
+	struct sock *sk = sk_vsock(vsk);
+	unsigned int remote_cid = vsk->remote_addr.svm_cid;
+
+	switch (sk->sk_type) {
+	case SOCK_DGRAM:
+		new_transport = transport_dgram;
+		break;
+	case SOCK_STREAM:
+		if (remote_cid <= VMADDR_CID_HOST ||
+		    (transport_g2h &&
+		     remote_cid == transport_g2h->get_local_cid()))
+			new_transport = transport_g2h;
+		else
+			new_transport = transport_h2g;
+		break;
+	default:
+		return -ESOCKTNOSUPPORT;
+	}
+
+	if (vsk->transport) {
+		if (vsk->transport == new_transport)
+			return 0;
+
+		vsk->transport->release(vsk);
+		vsk->transport->destruct(vsk);
+		/* The old transport has been torn down: drop the pointer so
+		 * that bailing out below cannot leave a stale transport
+		 * behind, which would be released and destructed a second
+		 * time when the socket is closed.
+		 */
+		vsk->transport = NULL;
+	}
+
+	if (!new_transport)
+		return -ENODEV;
+
+	vsk->transport = new_transport;
+
+	return vsk->transport->init(vsk, psk);
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+/* Return true if @cid matches the local CID reported by the registered
+ * guest->host transport, or if @cid is VMADDR_CID_HOST while a host->guest
+ * transport is registered; return false otherwise.
+ */
+bool vsock_find_cid(unsigned int cid)
+{
+	bool found = false;
+
+	/* Only ask for the guest CID when a guest->host transport exists. */
+	if (transport_g2h)
+		found = (cid == transport_g2h->get_local_cid());
+
+	/* The host CID is only usable when a host->guest transport exists. */
+	if (!found && transport_h2g)
+		found = (cid == VMADDR_CID_HOST);
+
+	return found;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
static struct sock *vsock_dequeue_accept(struct sock *listener)
{
struct vsock_sock *vlistener;
@@ -414,6 +480,9 @@ static int vsock_send_shutdown(struct sock *sk, int mode)
{
struct vsock_sock *vsk = vsock_sk(sk);
+ if (!vsk->transport)
+ return -ENODEV;
+
return vsk->transport->shutdown(vsk, mode);
}
@@ -530,7 +599,6 @@ static int __vsock_bind_dgram(struct vsock_sock *vsk,
static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
{
struct vsock_sock *vsk = vsock_sk(sk);
- u32 cid;
int retval;
/* First ensure this socket isn't already bound. */
@@ -540,10 +608,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
/* Now bind to the provided address or select appropriate values if
* none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). Note that
* like AF_INET prevents binding to a non-local IP address (in most
- * cases), we only allow binding to the local CID.
+ * cases), we only allow binding to a local CID.
*/
- cid = vsk->transport->get_local_cid();
- if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+ if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
return -EADDRNOTAVAIL;
switch (sk->sk_socket->type) {
@@ -592,7 +659,6 @@ static struct sock *__vsock_create(struct net *net,
sk->sk_type = type;
vsk = vsock_sk(sk);
- vsk->transport = transport_single;
vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
@@ -629,11 +695,6 @@ static struct sock *__vsock_create(struct net *net,
vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
}
- if (vsk->transport->init(vsk, psk) < 0) {
- sk_free(sk);
- return NULL;
- }
-
return sk;
}
@@ -649,7 +710,10 @@ static void __vsock_release(struct sock *sk, int level)
/* The release call is supposed to use lock_sock_nested()
* rather than lock_sock(), if a sock lock should be acquired.
*/
- vsk->transport->release(vsk);
+ if (vsk->transport)
+ vsk->transport->release(vsk);
+ else if (sk->sk_type == SOCK_STREAM)
+ vsock_remove_sock(vsk);
/* When "level" is SINGLE_DEPTH_NESTING, use the nested
* version to avoid the warning "possible recursive locking
@@ -677,7 +741,8 @@ static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
- vsk->transport->destruct(vsk);
+ if (vsk->transport)
+ vsk->transport->destruct(vsk);
/* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel.
@@ -894,7 +959,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
mask |= EPOLLIN | EPOLLRDNORM;
/* If there is something in the queue then we can read. */
- if (transport->stream_is_active(vsk) &&
+ if (transport && transport->stream_is_active(vsk) &&
!(sk->sk_shutdown & RCV_SHUTDOWN)) {
bool data_ready_now = false;
int ret = transport->notify_poll_in(
@@ -1144,7 +1209,6 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
err = 0;
sk = sock->sk;
vsk = vsock_sk(sk);
- transport = vsk->transport;
lock_sock(sk);
@@ -1172,19 +1236,26 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
goto out;
}
+ /* Set the remote address that we are connecting to. */
+ memcpy(&vsk->remote_addr, remote_addr,
+ sizeof(vsk->remote_addr));
+
+ err = vsock_assign_transport(vsk, NULL);
+ if (err)
+ goto out;
+
+ transport = vsk->transport;
+
/* The hypervisor and well-known contexts do not have socket
* endpoints.
*/
- if (!transport->stream_allow(remote_addr->svm_cid,
+ if (!transport ||
+ !transport->stream_allow(remote_addr->svm_cid,
remote_addr->svm_port)) {
err = -ENETUNREACH;
goto out;
}
- /* Set the remote address that we are connecting to. */
- memcpy(&vsk->remote_addr, remote_addr,
- sizeof(vsk->remote_addr));
-
err = vsock_auto_bind(vsk);
if (err)
goto out;
@@ -1584,7 +1655,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
goto out;
}
- if (sk->sk_state != TCP_ESTABLISHED ||
+ if (!transport || sk->sk_state != TCP_ESTABLISHED ||
!vsock_addr_bound(&vsk->local_addr)) {
err = -ENOTCONN;
goto out;
@@ -1710,7 +1781,7 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
lock_sock(sk);
- if (sk->sk_state != TCP_ESTABLISHED) {
+ if (!transport || sk->sk_state != TCP_ESTABLISHED) {
/* Recvmsg is supposed to return 0 if a peer performs an
* orderly shutdown. Differentiate between that case and when a
* peer has not connected or a local shutdown occured with the
@@ -1884,7 +1955,9 @@ static const struct proto_ops vsock_stream_ops = {
static int vsock_create(struct net *net, struct socket *sock,
int protocol, int kern)
{
+ struct vsock_sock *vsk;
struct sock *sk;
+ int ret;
if (!sock)
return -EINVAL;
@@ -1909,7 +1982,17 @@ static int vsock_create(struct net *net, struct socket *sock,
if (!sk)
return -ENOMEM;
- vsock_insert_unbound(vsock_sk(sk));
+ vsk = vsock_sk(sk);
+
+ if (sock->type == SOCK_DGRAM) {
+ ret = vsock_assign_transport(vsk, NULL);
+ if (ret < 0) {
+ sock_put(sk);
+ return ret;
+ }
+ }
+
+ vsock_insert_unbound(vsk);
return 0;
}
@@ -1924,11 +2007,20 @@ static long vsock_dev_do_ioctl(struct file *filp,
unsigned int cmd, void __user *ptr)
{
u32 __user *p = ptr;
+ u32 cid = VMADDR_CID_ANY;
int retval = 0;
switch (cmd) {
case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
- if (put_user(transport_single->get_local_cid(), p) != 0)
+ /* To be compatible with the VMCI behavior, we prioritize the
+ * guest CID instead of well-known host CID (VMADDR_CID_HOST).
+ */
+ if (transport_g2h)
+ cid = transport_g2h->get_local_cid();
+ else if (transport_h2g)
+ cid = transport_h2g->get_local_cid();
+
+ if (put_user(cid, p) != 0)
retval = -EFAULT;
break;
@@ -1968,24 +2060,13 @@ static struct miscdevice vsock_device = {
.fops = &vsock_device_ops,
};
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
{
- int err = mutex_lock_interruptible(&vsock_register_mutex);
+ int err = 0;
- if (err)
- return err;
+ vsock_init_tables();
- if (transport_single) {
- err = -EBUSY;
- goto err_busy;
- }
-
- /* Transport must be the owner of the protocol so that it can't
- * unload while there are open sockets.
- */
- vsock_proto.owner = owner;
- transport_single = t;
-
+ vsock_proto.owner = THIS_MODULE;
vsock_device.minor = MISC_DYNAMIC_MINOR;
err = misc_register(&vsock_device);
if (err) {
@@ -2006,7 +2087,6 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
goto err_unregister_proto;
}
- mutex_unlock(&vsock_register_mutex);
return 0;
err_unregister_proto:
@@ -2014,28 +2094,15 @@ int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
err_deregister_misc:
misc_deregister(&vsock_device);
err_reset_transport:
- transport_single = NULL;
-err_busy:
- mutex_unlock(&vsock_register_mutex);
return err;
}
-EXPORT_SYMBOL_GPL(__vsock_core_init);
-void vsock_core_exit(void)
+static void __exit vsock_exit(void)
{
- mutex_lock(&vsock_register_mutex);
-
misc_deregister(&vsock_device);
sock_unregister(AF_VSOCK);
proto_unregister(&vsock_proto);
-
- /* We do not want the assignment below re-ordered. */
- mb();
- transport_single = NULL;
-
- mutex_unlock(&vsock_register_mutex);
}
-EXPORT_SYMBOL_GPL(vsock_core_exit);
const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
{
@@ -2043,12 +2110,70 @@ const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_core_get_transport);
-static void __exit vsock_exit(void)
+int vsock_core_register(const struct vsock_transport *t, int features)
{
- /* Do nothing. This function makes this module removable. */
-}
+ const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+ int err = mutex_lock_interruptible(&vsock_register_mutex);
-module_init(vsock_init_tables);
+ if (err)
+ return err;
+
+ t_h2g = transport_h2g;
+ t_g2h = transport_g2h;
+ t_dgram = transport_dgram;
+
+ if (features & VSOCK_TRANSPORT_F_H2G) {
+ if (t_h2g) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_h2g = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_G2H) {
+ if (t_g2h) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_g2h = t;
+ }
+
+ if (features & VSOCK_TRANSPORT_F_DGRAM) {
+ if (t_dgram) {
+ err = -EBUSY;
+ goto err_busy;
+ }
+ t_dgram = t;
+ }
+
+ transport_h2g = t_h2g;
+ transport_g2h = t_g2h;
+ transport_dgram = t_dgram;
+
+err_busy:
+ mutex_unlock(&vsock_register_mutex);
+ return err;
+}
+EXPORT_SYMBOL_GPL(vsock_core_register);
+
+/* Remove transport @t from every role (host->guest, guest->host, DGRAM)
+ * it currently occupies. Roles occupied by a different transport are left
+ * untouched. The update is done under vsock_register_mutex.
+ */
+void vsock_core_unregister(const struct vsock_transport *t)
+{
+	const struct vsock_transport **slots[] = {
+		&transport_h2g,
+		&transport_g2h,
+		&transport_dgram,
+	};
+	unsigned int i;
+
+	mutex_lock(&vsock_register_mutex);
+
+	for (i = 0; i < ARRAY_SIZE(slots); i++) {
+		if (*slots[i] == t)
+			*slots[i] = NULL;
+	}
+
+	mutex_unlock(&vsock_register_mutex);
+}
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
+
+module_init(vsock_init);
module_exit(vsock_exit);
MODULE_AUTHOR("VMware, Inc.");