IB/mlx5: Support IB_WR_REG_SIG_MR

This patch implements IB_WR_REG_SIG_MR posted by the user.

Basically this WR involves 3 WQEs in order to prepare and properly
register the signature layout:

1. post UMR WR to register the sig_mr in one of two possible ways:
    * In case the user registered a single MR for data, the UMR data segment
      consists of:
      - single klm (data MR) passed by the user
      - BSF with signature attributes requested by the user.
    * In case the user registered 2 MRs, one for data and one for protection,
      the UMR consists of:
      - strided block format which includes data and protection MRs and
        their repetitive block format.
      - BSF with signature attributes requested by the user.

2. post SET_PSV in order to set the memory domain initial
   signature parameters passed by the user.
   SET_PSV is not signaled; it is solicited on error.

3. post SET_PSV in order to set the wire domain initial
   signature parameters passed by the user.
   SET_PSV is not signaled; it is solicited on error.

* After this compound WR we place a small fence for next WR to come.

This patch also introduces some helper functions to set the BSF correctly
and determine the signature format selectors.

Signed-off-by: Sagi Grimberg <sagig@mellanox.com>
Signed-off-by: Roland Dreier <roland@purestorage.com>
diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c
index 1dbadbf..67e7998 100644
--- a/drivers/infiniband/hw/mlx5/qp.c
+++ b/drivers/infiniband/hw/mlx5/qp.c
@@ -1777,6 +1777,26 @@
 	return cpu_to_be64(result);
 }
 
+/* Big-endian mkey field mask used by the signature UMR control segment. */
+static __be64 sig_mkey_mask(void)
+{
+	const u64 mask =
+		MLX5_MKEY_MASK_LEN		|
+		MLX5_MKEY_MASK_PAGE_SIZE	|
+		MLX5_MKEY_MASK_START_ADDR	|
+		MLX5_MKEY_MASK_EN_RINVAL	|
+		MLX5_MKEY_MASK_KEY		|
+		MLX5_MKEY_MASK_LR		|
+		MLX5_MKEY_MASK_LW		|
+		MLX5_MKEY_MASK_RR		|
+		MLX5_MKEY_MASK_RW		|
+		MLX5_MKEY_MASK_SMALL_FENCE	|
+		MLX5_MKEY_MASK_FREE		|
+		MLX5_MKEY_MASK_BSF_EN;
+
+	return cpu_to_be64(mask);
+}
+
 static void set_frwr_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
 				 struct ib_send_wr *wr, int li)
 {
@@ -1961,6 +1981,339 @@
 	return 0;
 }
 
+/* Protection field size for @type; 0 when the type is not handled here. */
+static u16 prot_field_size(enum ib_signature_type type)
+{
+	if (type == IB_SIG_TYPE_T10_DIF)
+		return MLX5_DIF_SIZE;
+
+	/* any other signature type has no known protection field size */
+	return 0;
+}
+
+/* Map a block size to its selector code (1..5); 0 if the size is not listed. */
+static u8 bs_selector(int block_size)
+{
+	static const int sizes[] = { 512, 520, 4096, 4160, 1073741824 };
+	unsigned int i;
+
+	for (i = 0; i < sizeof(sizes) / sizeof(sizes[0]); i++)
+		if (block_size == sizes[i])
+			return i + 1;
+	return 0;
+}
+
+/*
+ * Pick the signature format selector for @domain from its T10-DIF type and
+ * bg_type and store it in *selector.  @attr is not used.
+ * Returns 0 on success and 1 when the type or bg_type is not recognized.
+ */
+static int format_selector(struct ib_sig_attrs *attr,
+			   struct ib_sig_domain *domain,
+			   int *selector)
+{
+	/* values that may be stored in *selector */
+#define FORMAT_DIF_NONE		0
+#define FORMAT_DIF_CRC_INC	8
+#define FORMAT_DIF_CRC_NO_INC	12
+#define FORMAT_DIF_CSUM_INC	13
+#define FORMAT_DIF_CSUM_NO_INC	14
+
+	int inc_reftag;
+
+	switch (domain->sig.dif.type) {
+	case IB_T10DIF_NONE:
+		/* No DIF */
+		*selector = FORMAT_DIF_NONE;
+		return 0;
+	case IB_T10DIF_TYPE1: /* Fall through */
+	case IB_T10DIF_TYPE2:
+		/* types 1 and 2 always select the *_INC formats */
+		inc_reftag = 1;
+		break;
+	case IB_T10DIF_TYPE3:
+		/* type 3 selects *_INC only if type3_inc_reftag is set */
+		inc_reftag = !!domain->sig.dif.type3_inc_reftag;
+		break;
+	default:
+		return 1;
+	}
+
+	switch (domain->sig.dif.bg_type) {
+	case IB_T10DIF_CRC:
+		*selector = inc_reftag ? FORMAT_DIF_CRC_INC :
+					 FORMAT_DIF_CRC_NO_INC;
+		break;
+	case IB_T10DIF_CSUM:
+		*selector = inc_reftag ? FORMAT_DIF_CSUM_INC :
+					 FORMAT_DIF_CSUM_NO_INC;
+		break;
+	default:
+		return 1;
+	}
+
+	return 0;
+}
+
+/*
+ * Fill the basic BSF from @sig_attrs.  Zero it first: several fields are set
+ * on one path only.  Returns 0, or -EINVAL for unsupported attributes.
+ */
+static int mlx5_set_bsf(struct ib_mr *sig_mr,
+			struct ib_sig_attrs *sig_attrs,
+			struct mlx5_bsf *bsf, u32 data_size)
+{
+	struct mlx5_core_sig_ctx *msig = to_mmr(sig_mr)->sig;
+	struct mlx5_bsf_basic *basic = &bsf->basic;
+	struct ib_sig_domain *mem = &sig_attrs->mem;
+	struct ib_sig_domain *wire = &sig_attrs->wire;
+	int selector;
+
+	memset(bsf, 0, sizeof(*bsf));
+
+	/* only T10-DIF on both the memory and wire domains is accepted */
+	if (mem->sig_type != IB_SIG_TYPE_T10_DIF ||
+	    wire->sig_type != IB_SIG_TYPE_T10_DIF)
+		return -EINVAL;
+
+	/* Input domain check byte mask */
+	basic->check_byte_mask = sig_attrs->check_mask;
+	if (mem->sig.dif.pi_interval == wire->sig.dif.pi_interval &&
+	    mem->sig.dif.type == wire->sig.dif.type) {
+		/* Same block structure */
+		basic->bsf_size_sbs = 1 << 4;
+		if (mem->sig.dif.bg_type == wire->sig.dif.bg_type)
+			basic->wire.copy_byte_mask = 0xff;
+		else
+			basic->wire.copy_byte_mask = 0x3f;
+	} else {
+		basic->wire.bs_selector = bs_selector(wire->sig.dif.pi_interval);
+	}
+
+	basic->mem.bs_selector = bs_selector(mem->sig.dif.pi_interval);
+	basic->raw_data_size = cpu_to_be32(data_size);
+
+	if (format_selector(sig_attrs, mem, &selector))
+		return -EINVAL;
+	basic->m_bfs_psv = cpu_to_be32(selector << 24 |
+				       msig->psv_memory.psv_idx);
+
+	if (format_selector(sig_attrs, wire, &selector))
+		return -EINVAL;
+	basic->w_bfs_psv = cpu_to_be32(selector << 24 |
+				       msig->psv_wire.psv_idx);
+
+	return 0;
+}
+
+/*
+ * Lay out the UMR data segment at *seg - one data KLM, or a strided block if
+ * wr->wr.sig_handover.prot is set - then the BSF.  Returns 0 or -EINVAL.
+ */
+static int set_sig_data_segment(struct ib_send_wr *wr, struct mlx5_ib_qp *qp,
+				void **seg, int *size)
+{
+	struct ib_sig_attrs *sig_attrs = wr->wr.sig_handover.sig_attrs;
+	struct ib_mr *sig_mr = wr->wr.sig_handover.sig_mr;
+	u32 data_len = wr->sg_list->length;
+	u32 data_key = wr->sg_list->lkey;
+	u64 data_va = wr->sg_list->addr;
+	int wqe_size;
+
+	if (!wr->wr.sig_handover.prot) {
+		/**
+		 * No protection buffer was given, so construct:
+		 *                  ------------------
+		 *                 |     data_klm     |
+		 *                  ------------------
+		 *                 |       BSF        |
+		 *                  ------------------
+		 **/
+		struct mlx5_klm *data_klm = *seg;
+
+		data_klm->bcount = cpu_to_be32(data_len);
+		data_klm->key = cpu_to_be32(data_key);
+		data_klm->va = cpu_to_be64(data_va);
+		wqe_size = ALIGN(sizeof(*data_klm), 64);
+	} else {
+		/**
+		 * Source domain contains signature information
+		 * So need construct a strided block format:
+		 *               ---------------------------
+		 *              |     stride_block_ctrl     |
+		 *               ---------------------------
+		 *              |          data_klm         |
+		 *               ---------------------------
+		 *              |          prot_klm         |
+		 *               ---------------------------
+		 *              |             BSF           |
+		 *               ---------------------------
+		 **/
+		struct mlx5_stride_block_ctrl_seg *sblock_ctrl;
+		struct mlx5_stride_block_entry *data_sentry;
+		struct mlx5_stride_block_entry *prot_sentry;
+		u32 prot_key = wr->wr.sig_handover.prot->lkey;
+		u64 prot_va = wr->wr.sig_handover.prot->addr;
+		u16 block_size = sig_attrs->mem.sig.dif.pi_interval;
+		int prot_size = prot_field_size(sig_attrs->mem.sig_type);
+
+		sblock_ctrl = *seg;
+		data_sentry = (void *)sblock_ctrl + sizeof(*sblock_ctrl);
+		prot_sentry = (void *)data_sentry + sizeof(*data_sentry);
+		if (!prot_size) {
+			pr_err("Bad signature type given.\n");
+			return -EINVAL;
+		}
+		if (!block_size) {
+			pr_err("Bad block size given: %u\n", block_size);
+			return -EINVAL;
+		}
+		sblock_ctrl->bcount_per_cycle = cpu_to_be32(block_size +
+							    prot_size);
+		sblock_ctrl->op = cpu_to_be32(MLX5_STRIDE_BLOCK_OP);
+		sblock_ctrl->repeat_count = cpu_to_be32(data_len / block_size);
+		sblock_ctrl->num_entries = cpu_to_be16(2);
+
+		data_sentry->bcount = cpu_to_be16(block_size);
+		data_sentry->key = cpu_to_be32(data_key);
+		data_sentry->va = cpu_to_be64(data_va);
+		prot_sentry->bcount = cpu_to_be16(prot_size);
+		prot_sentry->key = cpu_to_be32(prot_key);
+
+		if (prot_key == data_key && prot_va == data_va) {
+			/*
+			 * Data and protection interleaved in one memory region
+			 */
+			prot_sentry->va = cpu_to_be64(data_va + block_size);
+			prot_sentry->stride = cpu_to_be16(block_size + prot_size);
+			data_sentry->stride = prot_sentry->stride;
+		} else {
+			/* The data and protection are two different buffers */
+			prot_sentry->va = cpu_to_be64(prot_va);
+			data_sentry->stride = cpu_to_be16(block_size);
+			prot_sentry->stride = cpu_to_be16(prot_size);
+		}
+		wqe_size = ALIGN(sizeof(*sblock_ctrl) + sizeof(*data_sentry) +
+				 sizeof(*prot_sentry), 64);
+	}
+
+	*seg += wqe_size;
+	*size += wqe_size / 16;
+	if (unlikely((*seg == qp->sq.qend)))
+		*seg = mlx5_get_send_wqe(qp, 0);
+
+	if (mlx5_set_bsf(sig_mr, sig_attrs, *seg, data_len))
+		return -EINVAL;
+
+	*seg += sizeof(struct mlx5_bsf);
+	*size += sizeof(struct mlx5_bsf) / 16;
+	if (unlikely((*seg == qp->sq.qend)))
+		*seg = mlx5_get_send_wqe(qp, 0);
+
+	return 0;
+}
+
+/* Zero @seg and describe the signature MR as a BSF-enabled KLM mkey. */
+static void set_sig_mkey_segment(struct mlx5_mkey_seg *seg,
+				 struct ib_send_wr *wr, u32 nelements,
+				 u32 length, u32 pdn)
+{
+	u32 rkey = wr->wr.sig_handover.sig_mr->rkey;
+	u32 pd_flags = MLX5_MKEY_REMOTE_INVAL | MLX5_MKEY_BSF_EN | pdn;
+	u16 klm_octo = be16_to_cpu(get_klm_octo(nelements));
+
+	memset(seg, 0, sizeof(*seg));
+	seg->flags = get_umr_flags(wr->wr.sig_handover.access_flags) |
+		     MLX5_ACCESS_MODE_KLM;
+	seg->qpn_mkey7_0 = cpu_to_be32((rkey & 0xff) | 0xffffff00);
+	seg->flags_pd = cpu_to_be32(pd_flags);
+	seg->len = cpu_to_be64(length);
+	seg->xlt_oct_size = cpu_to_be32(klm_octo);
+	seg->bsfs_octo_size = cpu_to_be32(MLX5_MKEY_BSF_OCTO_SIZE);
+}
+
+/* Zero @umr and fill in flags, KLM/BSF octoword counts and the mkey mask. */
+static void set_sig_umr_segment(struct mlx5_wqe_umr_ctrl_seg *umr,
+				struct ib_send_wr *wr, u32 nelements)
+{
+	memset(umr, 0, sizeof(*umr));
+	umr->mkey_mask = sig_mkey_mask();
+	umr->bsf_octowords = cpu_to_be16(MLX5_MKEY_BSF_OCTO_SIZE);
+	umr->klm_octowords = get_klm_octo(nelements);
+	umr->flags = MLX5_FLAGS_CHECK_FREE | MLX5_FLAGS_INLINE;
+}
+
+
+/*
+ * Build the UMR control, mkey and data segments that register the signature
+ * MR of @wr, advancing *seg and *size past each one.
+ * Returns 0 on success or -EINVAL when the request is not valid.
+ */
+static int set_sig_umr_wr(struct ib_send_wr *wr, struct mlx5_ib_qp *qp,
+			  void **seg, int *size)
+{
+	struct mlx5_ib_mr *sig_mr = to_mmr(wr->wr.sig_handover.sig_mr);
+	u32 pdn = get_pd(qp)->pdn;
+	int has_prot = !!wr->wr.sig_handover.prot;
+	u32 klm_oct_size;
+	int region_len;
+
+	if (unlikely(wr->num_sge != 1))
+		return -EINVAL;
+	if (unlikely(wr->wr.sig_handover.access_flags &
+		     IB_ACCESS_REMOTE_ATOMIC))
+		return -EINVAL;
+	if (unlikely(!sig_mr->sig) || unlikely(!qp->signature_en))
+		return -EINVAL;
+
+	/* protected region = data, plus protection when it is supplied */
+	region_len = wr->sg_list->length;
+	if (has_prot)
+		region_len += wr->wr.sig_handover.prot->length;
+
+	/* strided block format takes 3 KLM octowords, a single KLM takes 1 */
+	klm_oct_size = has_prot ? 3 : 1;
+
+	set_sig_umr_segment(*seg, wr, klm_oct_size);
+	*seg += sizeof(struct mlx5_wqe_umr_ctrl_seg);
+	*size += sizeof(struct mlx5_wqe_umr_ctrl_seg) / 16;
+	if (unlikely((*seg == qp->sq.qend)))
+		*seg = mlx5_get_send_wqe(qp, 0);
+
+	set_sig_mkey_segment(*seg, wr, klm_oct_size, region_len, pdn);
+	*seg += sizeof(struct mlx5_mkey_seg);
+	*size += sizeof(struct mlx5_mkey_seg) / 16;
+	if (unlikely((*seg == qp->sq.qend)))
+		*seg = mlx5_get_send_wqe(qp, 0);
+
+	return set_sig_data_segment(wr, qp, seg, size);
+}
+
+/*
+ * Write a SET_PSV segment for @psv_idx at *seg, seeded from @domain, and
+ * advance *seg and *size.  Returns 0, or -EINVAL for an unsupported type.
+ */
+static int set_psv_wr(struct ib_sig_domain *domain,
+		      u32 psv_idx, void **seg, int *size)
+{
+	struct mlx5_seg_set_psv *psv_seg = *seg;
+
+	if (domain->sig_type != IB_SIG_TYPE_T10_DIF) {
+		pr_err("Bad signature type given.\n");
+		return -EINVAL;
+	}
+
+	memset(psv_seg, 0, sizeof(*psv_seg));
+	psv_seg->psv_num = cpu_to_be32(psv_idx);
+	psv_seg->transient_sig = cpu_to_be32(domain->sig.dif.bg << 16 |
+					     domain->sig.dif.app_tag);
+	psv_seg->ref_tag = cpu_to_be32(domain->sig.dif.ref_tag);
+
+	*seg += sizeof(*psv_seg);
+	*size += sizeof(*psv_seg) / 16;
+	return 0;
+}
+
 static int set_frwr_li_wr(void **seg, struct ib_send_wr *wr, int *size,
 			  struct mlx5_core_dev *mdev, struct mlx5_ib_pd *pd, struct mlx5_ib_qp *qp)
 {
@@ -2108,6 +2461,7 @@
 	struct mlx5_ib_dev *dev = to_mdev(ibqp->device);
 	struct mlx5_core_dev *mdev = &dev->mdev;
 	struct mlx5_ib_qp *qp = to_mqp(ibqp);
+	struct mlx5_ib_mr *mr;
 	struct mlx5_wqe_data_seg *dpseg;
 	struct mlx5_wqe_xrc_seg *xrc;
 	struct mlx5_bf *bf = qp->bf;
@@ -2203,6 +2557,73 @@
 				num_sge = 0;
 				break;
 
+			case IB_WR_REG_SIG_MR:
+				qp->sq.wr_data[idx] = IB_WR_REG_SIG_MR;
+				mr = to_mmr(wr->wr.sig_handover.sig_mr);
+
+				ctrl->imm = cpu_to_be32(mr->ibmr.rkey);
+				err = set_sig_umr_wr(wr, qp, &seg, &size);
+				if (err) {
+					mlx5_ib_warn(dev, "\n");
+					*bad_wr = wr;
+					goto out;
+				}
+
+				finish_wqe(qp, ctrl, size, idx, wr->wr_id,
+					   nreq, get_fence(fence, wr),
+					   next_fence, MLX5_OPCODE_UMR);
+				/*
+				 * SET_PSV WQEs are not signaled and solicited
+				 * on error
+				 */
+				wr->send_flags &= ~IB_SEND_SIGNALED;
+				wr->send_flags |= IB_SEND_SOLICITED;
+				err = begin_wqe(qp, &seg, &ctrl, wr,
+						&idx, &size, nreq);
+				if (err) {
+					mlx5_ib_warn(dev, "\n");
+					err = -ENOMEM;
+					*bad_wr = wr;
+					goto out;
+				}
+
+				err = set_psv_wr(&wr->wr.sig_handover.sig_attrs->mem,
+						 mr->sig->psv_memory.psv_idx, &seg,
+						 &size);
+				if (err) {
+					mlx5_ib_warn(dev, "\n");
+					*bad_wr = wr;
+					goto out;
+				}
+
+				finish_wqe(qp, ctrl, size, idx, wr->wr_id,
+					   nreq, get_fence(fence, wr),
+					   next_fence, MLX5_OPCODE_SET_PSV);
+				err = begin_wqe(qp, &seg, &ctrl, wr,
+						&idx, &size, nreq);
+				if (err) {
+					mlx5_ib_warn(dev, "\n");
+					err = -ENOMEM;
+					*bad_wr = wr;
+					goto out;
+				}
+
+				next_fence = MLX5_FENCE_MODE_INITIATOR_SMALL;
+				err = set_psv_wr(&wr->wr.sig_handover.sig_attrs->wire,
+						 mr->sig->psv_wire.psv_idx, &seg,
+						 &size);
+				if (err) {
+					mlx5_ib_warn(dev, "\n");
+					*bad_wr = wr;
+					goto out;
+				}
+
+				finish_wqe(qp, ctrl, size, idx, wr->wr_id,
+					   nreq, get_fence(fence, wr),
+					   next_fence, MLX5_OPCODE_SET_PSV);
+				num_sge = 0;
+				goto skip_psv;
+
 			default:
 				break;
 			}
@@ -2286,6 +2707,7 @@
 		finish_wqe(qp, ctrl, size, idx, wr->wr_id, nreq,
 			   get_fence(fence, wr), next_fence,
 			   mlx5_ib_opcode[wr->opcode]);
+skip_psv:
 		if (0)
 			dump_wqe(qp, idx, size);
 	}