mac80211: IEEE 802.11 Extended Key ID support

Add support for Extended Key ID as defined in IEEE 802.11-2016.

 - Implement the nl80211 API for Extended Key ID
 - Extend mac80211 API to allow drivers to support Extended Key ID
 - Enable Extended Key ID by default for drivers only supporting SW
   crypto (e.g. mac80211_hwsim)
 - Allow unicast Tx usage to be suppressed (IEEE80211_KEY_FLAG_NO_AUTO_TX)
 - Select the decryption key based on the MPDU keyid
 - Enforce existing assumptions in the code that rekeys don't change the
   cipher

Signed-off-by: Alexander Wetzel <alexander@wetzel-home.de>
[remove module parameter]
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c
index 09dd1c2..14bbb7e 100644
--- a/net/mac80211/cfg.c
+++ b/net/mac80211/cfg.c
@@ -351,6 +351,36 @@ static int ieee80211_set_noack_map(struct wiphy *wiphy,
 	return 0;
 }
 
+/* Switch unicast Tx for station @mac_addr to its installed PTK @key_idx. */
+static int ieee80211_set_tx(struct ieee80211_sub_if_data *sdata,
+			    const u8 *mac_addr, u8 key_idx)
+{
+	struct ieee80211_local *local = sdata->local;
+	struct ieee80211_key *key;
+	struct sta_info *sta;
+	int ret = -EINVAL;
+
+	if (!wiphy_ext_feature_isset(local->hw.wiphy,
+				     NL80211_EXT_FEATURE_EXT_KEY_ID))
+		return -EINVAL;
+
+	sta = sta_info_get_bss(sdata, mac_addr);
+	if (!sta)
+		return -EINVAL;
+
+	if (sta->ptk_idx == key_idx)
+		return 0;
+
+	mutex_lock(&local->key_mtx);
+	key = key_mtx_dereference(local, sta->ptk[key_idx]);
+	/* only keys installed with NO_AUTO_TX can become the Tx key */
+	if (key && key->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX)
+		ret = ieee80211_set_tx_key(key);
+
+	mutex_unlock(&local->key_mtx);
+	return ret;
+}
+
 static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
 			     u8 key_idx, bool pairwise, const u8 *mac_addr,
 			     struct key_params *params)
@@ -365,6 +395,9 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
 	if (!ieee80211_sdata_running(sdata))
 		return -ENETDOWN;
 
+	if (pairwise && params->mode == NL80211_KEY_SET_TX)
+		return ieee80211_set_tx(sdata, mac_addr, key_idx);
+
 	/* reject WEP and TKIP keys if WEP failed to initialize */
 	switch (params->cipher) {
 	case WLAN_CIPHER_SUITE_WEP40:
@@ -396,6 +429,9 @@ static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev,
 	if (pairwise)
 		key->conf.flags |= IEEE80211_KEY_FLAG_PAIRWISE;
 
+	if (params->mode == NL80211_KEY_NO_TX)
+		key->conf.flags |= IEEE80211_KEY_FLAG_NO_AUTO_TX;
+
 	mutex_lock(&local->sta_mtx);
 
 	if (mac_addr) {
diff --git a/net/mac80211/debugfs.c b/net/mac80211/debugfs.c
index 2d43bc1..aa6f23e 100644
--- a/net/mac80211/debugfs.c
+++ b/net/mac80211/debugfs.c
@@ -221,6 +221,7 @@ static const char *hw_flag_names[] = {
 	FLAG(TX_STATUS_NO_AMPDU_LEN),
 	FLAG(SUPPORTS_MULTI_BSSID),
 	FLAG(SUPPORTS_ONLY_HE_MULTI_BSSID),
+	FLAG(EXT_KEY_ID_NATIVE),
 #undef FLAG
 };
 
diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index c5708f8..32094e2 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -1269,7 +1269,7 @@ struct ieee80211_local {
 
 	/*
 	 * Key mutex, protects sdata's key_list and sta_info's
-	 * key pointers (write access, they're RCU.)
+	 * key pointers and ptk_idx (write access, they're RCU.)
 	 */
 	struct mutex key_mtx;
 
diff --git a/net/mac80211/key.c b/net/mac80211/key.c
index 41b8db3..42d52cd 100644
--- a/net/mac80211/key.c
+++ b/net/mac80211/key.c
@@ -265,9 +265,24 @@ static void ieee80211_key_disable_hw_accel(struct ieee80211_key *key)
 			  sta ? sta->sta.addr : bcast_addr, ret);
 }
 
+/* Make @key its station's unicast Tx key; caller holds key_mtx. Returns 0. */
+int ieee80211_set_tx_key(struct ieee80211_key *key)
+{
+	struct sta_info *sta = key->sta;
+	struct ieee80211_local *local = key->local;
+
+	assert_key_lock(local);
+
+	sta->ptk_idx = key->conf.keyidx;
+	/* let the fast-xmit code re-evaluate the station with the new Tx key */
+	ieee80211_check_fast_xmit(sta);
+
+	return 0;
+}
+
 static int ieee80211_hw_key_replace(struct ieee80211_key *old_key,
 				    struct ieee80211_key *new_key,
-				    bool ptk0rekey)
+				    bool pairwise)
 {
 	struct ieee80211_sub_if_data *sdata;
 	struct ieee80211_local *local;
@@ -284,8 +299,9 @@ static int ieee80211_hw_key_replace(struct ieee80211_key *old_key,
 	assert_key_lock(old_key->local);
 	sta = old_key->sta;
 
-	/* PTK only using key ID 0 needs special handling on rekey */
-	if (new_key && sta && ptk0rekey) {
+	/* Unicast rekey without Extended Key ID needs special handling */
+	if (new_key && sta && pairwise &&
+	    rcu_access_pointer(sta->ptk[sta->ptk_idx]) == old_key) {
 		local = old_key->local;
 		sdata = old_key->sdata;
 
@@ -401,10 +417,6 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
 
 	if (old) {
 		idx = old->conf.keyidx;
-		/* TODO: proper implement and test "Extended Key ID for
-		 * Individually Addressed Frames" from IEEE 802.11-2016.
-		 * Till then always assume only key ID 0 is used for
-		 * pairwise keys.*/
 		ret = ieee80211_hw_key_replace(old, new, pairwise);
 	} else {
 		/* new must be provided in case old is not */
@@ -421,15 +433,20 @@ static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata,
 	if (sta) {
 		if (pairwise) {
 			rcu_assign_pointer(sta->ptk[idx], new);
-			sta->ptk_idx = idx;
-			if (new) {
+			if (new &&
+			    !(new->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX)) {
+				sta->ptk_idx = idx;
 				clear_sta_flag(sta, WLAN_STA_BLOCK_BA);
 				ieee80211_check_fast_xmit(sta);
 			}
 		} else {
 			rcu_assign_pointer(sta->gtk[idx], new);
 		}
-		if (new)
+		/* Only needed for the transition from no key -> key.
+		 * Still triggers unnecessarily when using Extended Key ID
+		 * and installing the second key ID for the first time.
+		 */
+		if (new && !old)
 			ieee80211_check_fast_rx(sta);
 	} else {
 		defunikey = old &&
@@ -745,16 +762,34 @@ int ieee80211_key_link(struct ieee80211_key *key,
 	 * can cause warnings to appear.
 	 */
 	bool delay_tailroom = sdata->vif.type == NL80211_IFTYPE_STATION;
-	int ret;
+	int ret = -EOPNOTSUPP;
 
 	mutex_lock(&sdata->local->key_mtx);
 
-	if (sta && pairwise)
+	if (sta && pairwise) {
+		struct ieee80211_key *alt_key;
+
 		old_key = key_mtx_dereference(sdata->local, sta->ptk[idx]);
-	else if (sta)
+		alt_key = key_mtx_dereference(sdata->local, sta->ptk[idx ^ 1]);
+
+		/* The rekey code assumes that the old and new key are using
+		 * the same cipher. Enforce the assumption for pairwise keys.
+		 */
+		if (key &&
+		    ((alt_key && alt_key->conf.cipher != key->conf.cipher) ||
+		     (old_key && old_key->conf.cipher != key->conf.cipher)))
+			goto out;
+	} else if (sta) {
 		old_key = key_mtx_dereference(sdata->local, sta->gtk[idx]);
-	else
+	} else {
 		old_key = key_mtx_dereference(sdata->local, sdata->keys[idx]);
+	}
+
+	/* Non-pairwise keys must also not switch the cipher on rekey */
+	if (!pairwise) {
+		if (key && old_key && old_key->conf.cipher != key->conf.cipher)
+			goto out;
+	}
 
 	/*
 	 * Silently accept key re-installation without really installing the
diff --git a/net/mac80211/key.h b/net/mac80211/key.h
index ebdb80b..f06fbd0 100644
--- a/net/mac80211/key.h
+++ b/net/mac80211/key.h
@@ -18,6 +18,7 @@
 
 #define NUM_DEFAULT_KEYS 4
 #define NUM_DEFAULT_MGMT_KEYS 2
+#define INVALID_PTK_KEYIDX 2 /* Keyidx always pointing to a NULL key for PTK */
 
 struct ieee80211_local;
 struct ieee80211_sub_if_data;
@@ -146,6 +147,7 @@ ieee80211_key_alloc(u32 cipher, int idx, size_t key_len,
 int ieee80211_key_link(struct ieee80211_key *key,
 		       struct ieee80211_sub_if_data *sdata,
 		       struct sta_info *sta);
+int ieee80211_set_tx_key(struct ieee80211_key *key);
 void ieee80211_key_free(struct ieee80211_key *key, bool delay_tailroom);
 void ieee80211_key_free_unused(struct ieee80211_key *key);
 void ieee80211_set_default_key(struct ieee80211_sub_if_data *sdata, int idx,
diff --git a/net/mac80211/main.c b/net/mac80211/main.c
index 800e676..5d6b930 100644
--- a/net/mac80211/main.c
+++ b/net/mac80211/main.c
@@ -1051,6 +1051,11 @@ int ieee80211_register_hw(struct ieee80211_hw *hw)
 		}
 	}
 
+	if (!local->ops->set_key ||
+	    ieee80211_hw_check(&local->hw, EXT_KEY_ID_NATIVE))
+		wiphy_ext_feature_set(local->hw.wiphy,
+				      NL80211_EXT_FEATURE_EXT_KEY_ID);
+
 	/*
 	 * Calculate scan IE length -- we need this to alloc
 	 * memory and to subtract from the driver limit. It
diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c
index 7f8d934..4a03c18 100644
--- a/net/mac80211/rx.c
+++ b/net/mac80211/rx.c
@@ -1005,23 +1005,43 @@ static int ieee80211_get_mmie_keyidx(struct sk_buff *skb)
 	return -1;
 }
 
-static int ieee80211_get_cs_keyid(const struct ieee80211_cipher_scheme *cs,
-				  struct sk_buff *skb)
+/* Return the key ID from the frame's IV/cipher header, -EINVAL if invalid. */
+static int ieee80211_get_keyid(struct sk_buff *skb,
+			       const struct ieee80211_cipher_scheme *cs)
 {
 	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
 	__le16 fc;
 	int hdrlen;
+	int minlen;
+	int key_idx_off;
+	u8 key_idx_shift;
 	u8 keyid;
 
 	fc = hdr->frame_control;
 	hdrlen = ieee80211_hdrlen(fc);
 
-	if (skb->len < hdrlen + cs->hdr_len)
+	if (cs) {
+		minlen = hdrlen + cs->hdr_len;
+		key_idx_off = hdrlen + cs->key_idx_off;
+		key_idx_shift = cs->key_idx_shift;
+	} else {
+		/* WEP, TKIP, CCMP and GCMP: key ID is bits 6-7 of IV octet 3 */
+		minlen = hdrlen + IEEE80211_WEP_IV_LEN;
+		key_idx_off = hdrlen + 3;
+		key_idx_shift = 6;
+	}
+
+	if (unlikely(skb->len < minlen))
 		return -EINVAL;
 
-	skb_copy_bits(skb, hdrlen + cs->key_idx_off, &keyid, 1);
-	keyid &= cs->key_idx_mask;
-	keyid >>= cs->key_idx_shift;
+	skb_copy_bits(skb, key_idx_off, &keyid, 1);
+	if (cs)
+		keyid &= cs->key_idx_mask;
+	keyid >>= key_idx_shift;
+
+	/* cs could use more than the usual two bits for the keyid */
+	if (unlikely(keyid >= NUM_DEFAULT_KEYS))
+		return -EINVAL;
 
 	return keyid;
 }
@@ -1852,9 +1872,9 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
 	struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
 	struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
 	int keyidx;
-	int hdrlen;
 	ieee80211_rx_result result = RX_DROP_UNUSABLE;
 	struct ieee80211_key *sta_ptk = NULL;
+	struct ieee80211_key *ptk_idx = NULL;
 	int mmie_keyidx = -1;
 	__le16 fc;
 	const struct ieee80211_cipher_scheme *cs = NULL;
@@ -1892,21 +1912,24 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
 
 	if (rx->sta) {
 		int keyid = rx->sta->ptk_idx;
+		sta_ptk = rcu_dereference(rx->sta->ptk[keyid]);
 
-		if (ieee80211_has_protected(fc) && rx->sta->cipher_scheme) {
+		if (ieee80211_has_protected(fc)) {
 			cs = rx->sta->cipher_scheme;
-			keyid = ieee80211_get_cs_keyid(cs, rx->skb);
+			keyid = ieee80211_get_keyid(rx->skb, cs);
+
 			if (unlikely(keyid < 0))
 				return RX_DROP_UNUSABLE;
+
+			ptk_idx = rcu_dereference(rx->sta->ptk[keyid]);
 		}
-		sta_ptk = rcu_dereference(rx->sta->ptk[keyid]);
 	}
 
 	if (!ieee80211_has_protected(fc))
 		mmie_keyidx = ieee80211_get_mmie_keyidx(rx->skb);
 
 	if (!is_multicast_ether_addr(hdr->addr1) && sta_ptk) {
-		rx->key = sta_ptk;
+		rx->key = ptk_idx ? ptk_idx : sta_ptk;
 		if ((status->flag & RX_FLAG_DECRYPTED) &&
 		    (status->flag & RX_FLAG_IV_STRIPPED))
 			return RX_CONTINUE;
@@ -1966,8 +1989,6 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
 		}
 		return RX_CONTINUE;
 	} else {
-		u8 keyid;
-
 		/*
 		 * The device doesn't give us the IV so we won't be
 		 * able to look up the key. That's ok though, we
@@ -1981,23 +2002,10 @@ ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx)
 		    (status->flag & RX_FLAG_IV_STRIPPED))
 			return RX_CONTINUE;
 
-		hdrlen = ieee80211_hdrlen(fc);
+		keyidx = ieee80211_get_keyid(rx->skb, cs);
 
-		if (cs) {
-			keyidx = ieee80211_get_cs_keyid(cs, rx->skb);
-
-			if (unlikely(keyidx < 0))
-				return RX_DROP_UNUSABLE;
-		} else {
-			if (rx->skb->len < 8 + hdrlen)
-				return RX_DROP_UNUSABLE; /* TODO: count this? */
-			/*
-			 * no need to call ieee80211_wep_get_keyidx,
-			 * it verifies a bunch of things we've done already
-			 */
-			skb_copy_bits(rx->skb, hdrlen + 3, &keyid, 1);
-			keyidx = keyid >> 6;
-		}
+		if (unlikely(keyidx < 0))
+			return RX_DROP_UNUSABLE;
 
 		/* check per-station GTK first, if multicast packet */
 		if (is_multicast_ether_addr(hdr->addr1) && rx->sta)
@@ -4042,12 +4050,8 @@ void ieee80211_check_fast_rx(struct sta_info *sta)
 		case WLAN_CIPHER_SUITE_GCMP_256:
 			break;
 		default:
-			/* we also don't want to deal with WEP or cipher scheme
-			 * since those require looking up the key idx in the
-			 * frame, rather than assuming the PTK is used
-			 * (we need to revisit this once we implement the real
-			 * PTK index, which is now valid in the spec, but we
-			 * haven't implemented that part yet)
+			/* We also don't want to deal with
+			 * WEP or cipher scheme.
 			 */
 			goto clear_rcu;
 		}
diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c
index a81e127..a4932ee 100644
--- a/net/mac80211/sta_info.c
+++ b/net/mac80211/sta_info.c
@@ -347,6 +347,15 @@ struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
 	sta->sta.max_rx_aggregation_subframes =
 		local->hw.max_rx_aggregation_subframes;
 
+	/* Extended Key ID needs to install keys for keyid 0 and 1 Rx-only.
+	 * The Tx path starts to use a key as soon as the key slot ptk_idx
+	 * refers to is not NULL. To not use the initial Rx-only key
+	 * prematurely for Tx, initialize ptk_idx to an impossible PTK keyid
+	 * which will always refer to a NULL key.
+	 */
+	BUILD_BUG_ON(ARRAY_SIZE(sta->ptk) <= INVALID_PTK_KEYIDX);
+	sta->ptk_idx = INVALID_PTK_KEYIDX;
+
 	sta->local = local;
 	sta->sdata = sdata;
 	sta->rx_stats.last_rx = jiffies;
diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c
index a3c6053..c49fd1e 100644
--- a/net/mac80211/tx.c
+++ b/net/mac80211/tx.c
@@ -3001,23 +3001,15 @@ void ieee80211_check_fast_xmit(struct sta_info *sta)
 		switch (build.key->conf.cipher) {
 		case WLAN_CIPHER_SUITE_CCMP:
 		case WLAN_CIPHER_SUITE_CCMP_256:
-			/* add fixed key ID */
-			if (gen_iv) {
-				(build.hdr + build.hdr_len)[3] =
-					0x20 | (build.key->conf.keyidx << 6);
+			if (gen_iv)
 				build.pn_offs = build.hdr_len;
-			}
 			if (gen_iv || iv_spc)
 				build.hdr_len += IEEE80211_CCMP_HDR_LEN;
 			break;
 		case WLAN_CIPHER_SUITE_GCMP:
 		case WLAN_CIPHER_SUITE_GCMP_256:
-			/* add fixed key ID */
-			if (gen_iv) {
-				(build.hdr + build.hdr_len)[3] =
-					0x20 | (build.key->conf.keyidx << 6);
+			if (gen_iv)
 				build.pn_offs = build.hdr_len;
-			}
 			if (gen_iv || iv_spc)
 				build.hdr_len += IEEE80211_GCMP_HDR_LEN;
 			break;
@@ -3388,6 +3380,7 @@ static void ieee80211_xmit_fast_finish(struct ieee80211_sub_if_data *sdata,
 			pn = atomic64_inc_return(&key->conf.tx_pn);
 			crypto_hdr[0] = pn;
 			crypto_hdr[1] = pn >> 8;
+			crypto_hdr[3] = 0x20 | (key->conf.keyidx << 6);
 			crypto_hdr[4] = pn >> 16;
 			crypto_hdr[5] = pn >> 24;
 			crypto_hdr[6] = pn >> 32;