mac80211/minstrel_ht: use the new rate control API

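Pass the rate selection table to mac80211 from the new
minstrel_ht_update_rates(), which is called at station init and from
the tx status path whenever a rate is downgraded or the statistics are
updated.
Only rates for sample attempts are set in info->control.rates.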

Signed-off-by: Felix Fietkau <nbd@openwrt.org>
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
diff --git a/net/mac80211/rc80211_minstrel_ht.c b/net/mac80211/rc80211_minstrel_ht.c
index a8e979e..5b2d301 100644
--- a/net/mac80211/rc80211_minstrel_ht.c
+++ b/net/mac80211/rc80211_minstrel_ht.c
@@ -126,6 +126,9 @@
 
 static u8 sample_table[SAMPLE_COLUMNS][MCS_GROUP_RATES];
 
+static void
+minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi);
+
 /*
  * Look up an MCS group index based on mac80211 rate information
  */
@@ -465,7 +468,7 @@
 	struct ieee80211_tx_rate *ar = info->status.rates;
 	struct minstrel_rate_stats *rate, *rate2;
 	struct minstrel_priv *mp = priv;
-	bool last;
+	bool last, update = false;
 	int i;
 
 	if (!msp->is_ht)
@@ -514,21 +517,29 @@
 	rate = minstrel_get_ratestats(mi, mi->max_tp_rate);
 	if (rate->attempts > 30 &&
 	    MINSTREL_FRAC(rate->success, rate->attempts) <
-	    MINSTREL_FRAC(20, 100))
+	    MINSTREL_FRAC(20, 100)) {
 		minstrel_downgrade_rate(mi, &mi->max_tp_rate, true);
+		update = true;
+	}
 
 	rate2 = minstrel_get_ratestats(mi, mi->max_tp_rate2);
 	if (rate2->attempts > 30 &&
 	    MINSTREL_FRAC(rate2->success, rate2->attempts) <
-	    MINSTREL_FRAC(20, 100))
+	    MINSTREL_FRAC(20, 100)) {
 		minstrel_downgrade_rate(mi, &mi->max_tp_rate2, false);
+		update = true;
+	}
 
 	if (time_after(jiffies, mi->stats_update + (mp->update_interval / 2 * HZ) / 1000)) {
+		update = true;
 		minstrel_ht_update_stats(mp, mi);
 		if (!(info->flags & IEEE80211_TX_CTL_AMPDU) &&
 		    mi->max_prob_rate / MCS_GROUP_RATES != MINSTREL_CCK_GROUP)
 			minstrel_aggr_check(sta, skb);
 	}
+
+	if (update)
+		minstrel_ht_update_rates(mp, mi);
 }
 
 static void
@@ -592,36 +603,71 @@
 
 static void
 minstrel_ht_set_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
-                     struct ieee80211_tx_rate *rate, int index,
-                     bool sample, bool rtscts)
+                     struct ieee80211_sta_rates *ratetbl, int offset, int index)
 {
 	const struct mcs_group *group = &minstrel_mcs_groups[index / MCS_GROUP_RATES];
 	struct minstrel_rate_stats *mr;
+	u8 idx; /* rate index written to ratetbl->rate[offset] */
+	u16 flags; /* IEEE80211_TX_RC_* flags written to the same entry */
 
 	mr = minstrel_get_ratestats(mi, index);
 	if (!mr->retry_updated)
 		minstrel_calc_retransmit(mp, mi, index);
 
-	if (sample)
-		rate->count = 1;
-	else if (mr->probability < MINSTREL_FRAC(20, 100))
-		rate->count = 2;
-	else if (rtscts)
-		rate->count = mr->retry_count_rtscts;
-	else
-		rate->count = mr->retry_count;
-
-	rate->flags = 0;
-	if (rtscts)
-		rate->flags |= IEEE80211_TX_RC_USE_RTS_CTS;
-
-	if (index / MCS_GROUP_RATES == MINSTREL_CCK_GROUP) {
-		rate->idx = mp->cck_rates[index % ARRAY_SIZE(mp->cck_rates)];
-		return;
+	if (mr->probability < MINSTREL_FRAC(20, 100) || !mr->retry_count) {
+		ratetbl->rate[offset].count = 2;
+		ratetbl->rate[offset].count_rts = 2;
+		ratetbl->rate[offset].count_cts = 2;
+	} else { /* use the retry counts stored in the rate stats */
+		ratetbl->rate[offset].count = mr->retry_count;
+		ratetbl->rate[offset].count_cts = mr->retry_count;
+		ratetbl->rate[offset].count_rts = mr->retry_count_rtscts;
 	}
 
-	rate->flags |= IEEE80211_TX_RC_MCS | group->flags;
-	rate->idx = index % MCS_GROUP_RATES + (group->streams - 1) * MCS_GROUP_RATES;
+	if (index / MCS_GROUP_RATES == MINSTREL_CCK_GROUP) {
+		idx = mp->cck_rates[index % ARRAY_SIZE(mp->cck_rates)];
+		flags = 0;
+	} else {
+		idx = index % MCS_GROUP_RATES +
+		      (group->streams - 1) * MCS_GROUP_RATES;
+		flags = IEEE80211_TX_RC_MCS | group->flags;
+	}
+
+	if (offset > 0) { /* entries after the first always use RTS/CTS */
+		ratetbl->rate[offset].count = ratetbl->rate[offset].count_rts;
+		flags |= IEEE80211_TX_RC_USE_RTS_CTS;
+	}
+
+	ratetbl->rate[offset].idx = idx;
+	ratetbl->rate[offset].flags = flags;
+}
+
+static void
+minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi)
+{
+	struct ieee80211_sta_rates *rates;
+	int i = 0; /* next free slot in the rate table */
+
+	rates = kzalloc(sizeof(*rates), GFP_ATOMIC);
+	if (!rates)
+		return; /* allocation failed: no new table is passed on */
+
+	/* Build the rate table in order; first entry is max_tp_rate */
+	minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_tp_rate);
+
+	if (mp->hw->max_rates >= 3) {
+		/* At least 3 tx rates supported, use max_tp_rate2 next */
+		minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_tp_rate2);
+	}
+
+	if (mp->hw->max_rates >= 2) {
+		/* At least 2 tx rates supported, use max_prob_rate
+		 * as the last entry */
+		minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_prob_rate);
+	}
+
+	rates->rate[i].idx = -1; /* mark slot after last filled entry */
+	rate_control_set_rates(mp->hw, mi->sta, rates);
 }
 
 static inline int
@@ -711,13 +757,13 @@
 minstrel_ht_get_rate(void *priv, struct ieee80211_sta *sta, void *priv_sta,
                      struct ieee80211_tx_rate_control *txrc)
 {
+	const struct mcs_group *sample_group;
 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
-	struct ieee80211_tx_rate *ar = info->status.rates;
+	struct ieee80211_tx_rate *rate = &info->status.rates[0];
 	struct minstrel_ht_sta_priv *msp = priv_sta;
 	struct minstrel_ht_sta *mi = &msp->ht;
 	struct minstrel_priv *mp = priv;
 	int sample_idx;
-	bool sample = false;
 
 	if (rate_control_send_low(sta, priv_sta, txrc))
 		return;
@@ -745,51 +791,6 @@
 	}
 #endif
 
-	if (sample_idx >= 0) {
-		sample = true;
-		minstrel_ht_set_rate(mp, mi, &ar[0], sample_idx,
-			true, false);
-		info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE;
-	} else {
-		minstrel_ht_set_rate(mp, mi, &ar[0], mi->max_tp_rate,
-			false, false);
-	}
-
-	if (mp->hw->max_rates >= 3) {
-		/*
-		 * At least 3 tx rates supported, use
-		 * sample_rate -> max_tp_rate -> max_prob_rate for sampling and
-		 * max_tp_rate -> max_tp_rate2 -> max_prob_rate by default.
-		 */
-		if (sample_idx >= 0)
-			minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_tp_rate,
-				false, false);
-		else
-			minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_tp_rate2,
-				false, true);
-
-		minstrel_ht_set_rate(mp, mi, &ar[2], mi->max_prob_rate,
-				     false, !sample);
-
-		ar[3].count = 0;
-		ar[3].idx = -1;
-	} else if (mp->hw->max_rates == 2) {
-		/*
-		 * Only 2 tx rates supported, use
-		 * sample_rate -> max_prob_rate for sampling and
-		 * max_tp_rate -> max_prob_rate by default.
-		 */
-		minstrel_ht_set_rate(mp, mi, &ar[1], mi->max_prob_rate,
-				     false, !sample);
-
-		ar[2].count = 0;
-		ar[2].idx = -1;
-	} else {
-		/* Not using MRR, only use the first rate */
-		ar[1].count = 0;
-		ar[1].idx = -1;
-	}
-
 	mi->total_packets++;
 
 	/* wraparound */
@@ -797,6 +798,16 @@
 		mi->total_packets = 0;
 		mi->sample_packets = 0;
 	}
+
+	if (sample_idx < 0)
+		return;
+
+	sample_group = &minstrel_mcs_groups[sample_idx / MCS_GROUP_RATES];
+	info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE;
+	rate->idx = sample_idx % MCS_GROUP_RATES +
+		    (sample_group->streams - 1) * MCS_GROUP_RATES;
+	rate->flags = IEEE80211_TX_RC_MCS | sample_group->flags;
+	rate->count = 1;
 }
 
 static void
@@ -846,6 +857,8 @@
 
 	msp->is_ht = true;
 	memset(mi, 0, sizeof(*mi));
+
+	mi->sta = sta;
 	mi->stats_update = jiffies;
 
 	ack_dur = ieee80211_frame_duration(sband->band, 10, 60, 1, 1);
@@ -907,8 +920,9 @@
 	if (!n_supported)
 		goto use_legacy;
 
-	/* init {mi,mi->groups[*]}->{max_tp_rate,max_tp_rate2,max_prob_rate} */
+	/* create an initial rate table with the lowest supported rates */
 	minstrel_ht_update_stats(mp, mi);
+	minstrel_ht_update_rates(mp, mi);
 
 	return;