usb: cdc-wdm: Fix race between write and disconnect

Unify mutexes to fix a race between write and disconnect
and shift the test for disconnection to always report it.

Signed-off-by: Oliver Neukum <neukum@b1-systems.de>
Signed-off-by: Greg Kroah-Hartman <gregkh@suse.de>

diff --git a/drivers/usb/class/cdc-wdm.c b/drivers/usb/class/cdc-wdm.c
index 18aafcb..cf1c5fb 100644
--- a/drivers/usb/class/cdc-wdm.c
+++ b/drivers/usb/class/cdc-wdm.c
@@ -87,9 +87,7 @@
 	int			count;
 	dma_addr_t		shandle;
 	dma_addr_t		ihandle;
-	struct mutex		wlock;
-	struct mutex		rlock;
-	struct mutex		plock;
+	struct mutex		lock;
 	wait_queue_head_t	wait;
 	struct work_struct	rxwork;
 	int			werr;
@@ -305,14 +303,38 @@
 	if (we < 0)
 		return -EIO;
 
-	r = mutex_lock_interruptible(&desc->wlock); /* concurrent writes */
-	rv = -ERESTARTSYS;
-	if (r)
+	desc->outbuf = buf = kmalloc(count, GFP_KERNEL);
+	if (!buf) {
+		rv = -ENOMEM;
 		goto outnl;
+	}
+
+	r = copy_from_user(buf, buffer, count);
+	if (r > 0) {
+		kfree(buf);
+		rv = -EFAULT;
+		goto outnl;
+	}
+
+	/* concurrent writes and disconnect */
+	r = mutex_lock_interruptible(&desc->lock);
+	rv = -ERESTARTSYS;
+	if (r) {
+		kfree(buf);
+		goto outnl;
+	}
+
+	if (test_bit(WDM_DISCONNECTING, &desc->flags)) {
+		kfree(buf);
+		rv = -ENODEV;
+		goto outnp;
+	}
 
 	r = usb_autopm_get_interface(desc->intf);
-	if (r < 0)
+	if (r < 0) {
+		kfree(buf);
 		goto outnp;
+	}
 
 	if (!file->f_flags && O_NONBLOCK)
 		r = wait_event_interruptible(desc->wait, !test_bit(WDM_IN_USE,
@@ -320,24 +342,8 @@
 	else
 		if (test_bit(WDM_IN_USE, &desc->flags))
 			r = -EAGAIN;
-	if (r < 0)
-		goto out;
-
-	if (test_bit(WDM_DISCONNECTING, &desc->flags)) {
-		rv = -ENODEV;
-		goto out;
-	}
-
-	desc->outbuf = buf = kmalloc(count, GFP_KERNEL);
-	if (!buf) {
-		rv = -ENOMEM;
-		goto out;
-	}
-
-	r = copy_from_user(buf, buffer, count);
-	if (r > 0) {
+	if (r < 0) {
 		kfree(buf);
-		rv = -EFAULT;
 		goto out;
 	}
 
@@ -374,7 +380,7 @@
 out:
 	usb_autopm_put_interface(desc->intf);
 outnp:
-	mutex_unlock(&desc->wlock);
+	mutex_unlock(&desc->lock);
 outnl:
 	return rv < 0 ? rv : count;
 }
@@ -387,7 +393,7 @@
 	struct wdm_device *desc = file->private_data;
 
 
-	rv = mutex_lock_interruptible(&desc->rlock); /*concurrent reads */
+	rv = mutex_lock_interruptible(&desc->lock); /*concurrent reads */
 	if (rv < 0)
 		return -ERESTARTSYS;
 
@@ -465,7 +471,7 @@
 	rv = cntr;
 
 err:
-	mutex_unlock(&desc->rlock);
+	mutex_unlock(&desc->lock);
 	if (rv < 0 && rv != -EAGAIN)
 		dev_err(&desc->intf->dev, "wdm_read: exit error\n");
 	return rv;
@@ -533,7 +539,7 @@
 	}
 	intf->needs_remote_wakeup = 1;
 
-	mutex_lock(&desc->plock);
+	mutex_lock(&desc->lock);
 	if (!desc->count++) {
 		rv = usb_submit_urb(desc->validity, GFP_KERNEL);
 		if (rv < 0) {
@@ -544,7 +550,7 @@
 	} else {
 		rv = 0;
 	}
-	mutex_unlock(&desc->plock);
+	mutex_unlock(&desc->lock);
 	usb_autopm_put_interface(desc->intf);
 out:
 	mutex_unlock(&wdm_mutex);
@@ -556,9 +562,9 @@
 	struct wdm_device *desc = file->private_data;
 
 	mutex_lock(&wdm_mutex);
-	mutex_lock(&desc->plock);
+	mutex_lock(&desc->lock);
 	desc->count--;
-	mutex_unlock(&desc->plock);
+	mutex_unlock(&desc->lock);
 
 	if (!desc->count) {
 		dev_dbg(&desc->intf->dev, "wdm_release: cleanup");
@@ -655,9 +661,7 @@
 	desc = kzalloc(sizeof(struct wdm_device), GFP_KERNEL);
 	if (!desc)
 		goto out;
-	mutex_init(&desc->wlock);
-	mutex_init(&desc->rlock);
-	mutex_init(&desc->plock);
+	mutex_init(&desc->lock);
 	spin_lock_init(&desc->iuspin);
 	init_waitqueue_head(&desc->wait);
 	desc->wMaxCommand = maxcom;
@@ -772,7 +776,9 @@
 	clear_bit(WDM_IN_USE, &desc->flags);
 	spin_unlock_irqrestore(&desc->iuspin, flags);
 	cancel_work_sync(&desc->rxwork);
+	mutex_lock(&desc->lock);
 	kill_urbs(desc);
+	mutex_unlock(&desc->lock);
 	wake_up_all(&desc->wait);
 	if (!desc->count)
 		cleanup(desc);
@@ -786,7 +792,7 @@
 
 	dev_dbg(&desc->intf->dev, "wdm%d_suspend\n", intf->minor);
 
-	mutex_lock(&desc->plock);
+	mutex_lock(&desc->lock);
 #ifdef CONFIG_PM
 	if ((message.event & PM_EVENT_AUTO) &&
 			test_bit(WDM_IN_USE, &desc->flags)) {
@@ -798,7 +804,7 @@
 #ifdef CONFIG_PM
 	}
 #endif
-	mutex_unlock(&desc->plock);
+	mutex_unlock(&desc->lock);
 
 	return rv;
 }
@@ -821,9 +827,9 @@
 	int rv;
 
 	dev_dbg(&desc->intf->dev, "wdm%d_resume\n", intf->minor);
-	mutex_lock(&desc->plock);
+	mutex_lock(&desc->lock);
 	rv = recover_from_urb_loss(desc);
-	mutex_unlock(&desc->plock);
+	mutex_unlock(&desc->lock);
 	return rv;
 }
 
@@ -831,7 +837,7 @@
 {
 	struct wdm_device *desc = usb_get_intfdata(intf);
 
-	mutex_lock(&desc->plock);
+	mutex_lock(&desc->lock);
 	return 0;
 }
 
@@ -841,7 +847,7 @@
 	int rv;
 
 	rv = recover_from_urb_loss(desc);
-	mutex_unlock(&desc->plock);
+	mutex_unlock(&desc->lock);
 	return 0;
 }