KVM: x86: Add a common TSC scaling function

VMX and SVM calculate the TSC scaling ratio in a similar logic, so this
patch generalizes it to a common TSC scaling function.

Signed-off-by: Haozhong Zhang <haozhong.zhang@intel.com>
[Inline the multiplication and shift steps into mul_u64_u64_shr.  Remove
 BUG_ON.  - Paolo]
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index f3354bd..52d1419 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1238,6 +1238,8 @@
 void kvm_define_shared_msr(unsigned index, u32 msr);
 int kvm_set_shared_msr(unsigned index, u64 val, u64 mask);
 
+u64 kvm_scale_tsc(struct kvm_vcpu *vcpu, u64 tsc);
+
 unsigned long kvm_get_linear_rip(struct kvm_vcpu *vcpu);
 bool kvm_is_linear_rip(struct kvm_vcpu *vcpu, unsigned long linear_rip);
 
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index 9c92e6f..65f4f19 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -212,7 +212,6 @@
 static int nested_svm_vmexit(struct vcpu_svm *svm);
 static int nested_svm_check_exception(struct vcpu_svm *svm, unsigned nr,
 				      bool has_error_code, u32 error_code);
-static u64 __scale_tsc(u64 ratio, u64 tsc);
 
 enum {
 	VMCB_INTERCEPTS, /* Intercept vectors, TSC offset,
@@ -892,21 +891,7 @@
 		kvm_enable_efer_bits(EFER_FFXSR);
 
 	if (boot_cpu_has(X86_FEATURE_TSCRATEMSR)) {
-		u64 max;
-
 		kvm_has_tsc_control = true;
-
-		/*
-		 * Make sure the user can only configure tsc_khz values that
-		 * fit into a signed integer.
-		 * A min value is not calculated needed because it will always
-		 * be 1 on all machines and a value of 0 is used to disable
-		 * tsc-scaling for the vcpu.
-		 */
-		max = min(0x7fffffffULL, __scale_tsc(tsc_khz, TSC_RATIO_MAX));
-
-		kvm_max_guest_tsc_khz = max;
-
 		kvm_max_tsc_scaling_ratio = TSC_RATIO_MAX;
 		kvm_tsc_scaling_ratio_frac_bits = 32;
 	}
@@ -972,31 +957,6 @@
 	seg->base = 0;
 }
 
-static u64 __scale_tsc(u64 ratio, u64 tsc)
-{
-	u64 mult, frac, _tsc;
-
-	mult  = ratio >> 32;
-	frac  = ratio & ((1ULL << 32) - 1);
-
-	_tsc  = tsc;
-	_tsc *= mult;
-	_tsc += (tsc >> 32) * frac;
-	_tsc += ((tsc & ((1ULL << 32) - 1)) * frac) >> 32;
-
-	return _tsc;
-}
-
-static u64 svm_scale_tsc(struct kvm_vcpu *vcpu, u64 tsc)
-{
-	u64 _tsc = tsc;
-
-	if (vcpu->arch.tsc_scaling_ratio != TSC_RATIO_DEFAULT)
-		_tsc = __scale_tsc(vcpu->arch.tsc_scaling_ratio, tsc);
-
-	return _tsc;
-}
-
 static void svm_set_tsc_khz(struct kvm_vcpu *vcpu, u32 user_tsc_khz, bool scale)
 {
 	u64 ratio;
@@ -1065,7 +1025,7 @@
 	if (host) {
 		if (vcpu->arch.tsc_scaling_ratio != TSC_RATIO_DEFAULT)
 			WARN_ON(adjustment < 0);
-		adjustment = svm_scale_tsc(vcpu, (u64)adjustment);
+		adjustment = kvm_scale_tsc(vcpu, (u64)adjustment);
 	}
 
 	svm->vmcb->control.tsc_offset += adjustment;
@@ -1083,7 +1043,7 @@
 {
 	u64 tsc;
 
-	tsc = svm_scale_tsc(vcpu, rdtsc());
+	tsc = kvm_scale_tsc(vcpu, rdtsc());
 
 	return target_tsc - tsc;
 }
@@ -3075,7 +3035,7 @@
 {
 	struct vmcb *vmcb = get_host_vmcb(to_svm(vcpu));
 	return vmcb->control.tsc_offset +
-		svm_scale_tsc(vcpu, host_tsc);
+		kvm_scale_tsc(vcpu, host_tsc);
 }
 
 static int svm_get_msr(struct kvm_vcpu *vcpu, struct msr_data *msr_info)
@@ -3085,7 +3045,7 @@
 	switch (msr_info->index) {
 	case MSR_IA32_TSC: {
 		msr_info->data = svm->vmcb->control.tsc_offset +
-			svm_scale_tsc(vcpu, rdtsc());
+			kvm_scale_tsc(vcpu, rdtsc());
 
 		break;
 	}
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index ef5b9d6..1473e64 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -1329,6 +1329,33 @@
 	vcpu->arch.ia32_tsc_adjust_msr += offset - curr_offset;
 }
 
+/*
+ * Multiply tsc by a fixed point number represented by ratio.
+ *
+ * The most significant 64-N bits (mult) of ratio represent the
+ * integral part of the fixed point number; the remaining N bits
+ * (frac) represent the fractional part, ie. ratio represents a fixed
+ * point number (mult + frac * 2^(-N)).
+ *
+ * N equals to kvm_tsc_scaling_ratio_frac_bits.
+ */
+static inline u64 __scale_tsc(u64 ratio, u64 tsc)
+{
+	return mul_u64_u64_shr(tsc, ratio, kvm_tsc_scaling_ratio_frac_bits);
+}
+
+u64 kvm_scale_tsc(struct kvm_vcpu *vcpu, u64 tsc)
+{
+	u64 _tsc = tsc;
+	u64 ratio = vcpu->arch.tsc_scaling_ratio;
+
+	if (ratio != kvm_default_tsc_scaling_ratio)
+		_tsc = __scale_tsc(ratio, tsc);
+
+	return _tsc;
+}
+EXPORT_SYMBOL_GPL(kvm_scale_tsc);
+
 void kvm_write_tsc(struct kvm_vcpu *vcpu, struct msr_data *msr)
 {
 	struct kvm *kvm = vcpu->kvm;
@@ -7371,8 +7398,19 @@
 	if (r != 0)
 		return r;
 
-	if (kvm_has_tsc_control)
+	if (kvm_has_tsc_control) {
+		/*
+		 * Make sure the user can only configure tsc_khz values that
+		 * fit into a signed integer.
+		 * A min value is not calculated needed because it will always
+		 * be 1 on all machines.
+		 */
+		u64 max = min(0x7fffffffULL,
+			      __scale_tsc(kvm_max_tsc_scaling_ratio, tsc_khz));
+		kvm_max_guest_tsc_khz = max;
+
 		kvm_default_tsc_scaling_ratio = 1ULL << kvm_tsc_scaling_ratio_frac_bits;
+	}
 
 	kvm_init_msr_list();
 	return 0;
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 242a6d2..5706a21 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -1183,4 +1183,5 @@
 int kvm_arch_update_irqfd_routing(struct kvm *kvm, unsigned int host_irq,
 				  uint32_t guest_irq, bool set);
 #endif /* CONFIG_HAVE_KVM_IRQ_BYPASS */
+
 #endif
diff --git a/include/linux/math64.h b/include/linux/math64.h
index c45c089..44282ec 100644
--- a/include/linux/math64.h
+++ b/include/linux/math64.h
@@ -142,6 +142,13 @@
 }
 #endif /* mul_u64_u32_shr */
 
+#ifndef mul_u64_u64_shr
+static inline u64 mul_u64_u64_shr(u64 a, u64 mul, unsigned int shift)
+{
+	return (u64)(((unsigned __int128)a * mul) >> shift);
+}
+#endif /* mul_u64_u64_shr */
+
 #else
 
 #ifndef mul_u64_u32_shr
@@ -161,6 +168,50 @@
 }
 #endif /* mul_u64_u32_shr */
 
+#ifndef mul_u64_u64_shr
+static inline u64 mul_u64_u64_shr(u64 a, u64 b, unsigned int shift)
+{
+	union {
+		u64 ll;
+		struct {
+#ifdef __BIG_ENDIAN
+			u32 high, low;
+#else
+			u32 low, high;
+#endif
+		} l;
+	} rl, rm, rn, rh, a0, b0;
+	u64 c;
+
+	a0.ll = a;
+	b0.ll = b;
+
+	rl.ll = (u64)a0.l.low * b0.l.low;
+	rm.ll = (u64)a0.l.low * b0.l.high;
+	rn.ll = (u64)a0.l.high * b0.l.low;
+	rh.ll = (u64)a0.l.high * b0.l.high;
+
+	/*
+	 * Each of these lines computes a 64-bit intermediate result into "c",
+	 * starting at bits 32-95.  The low 32-bits go into the result of the
+	 * multiplication, the high 32-bits are carried into the next step.
+	 */
+	rl.l.high = c = (u64)rl.l.high + rm.l.low + rn.l.low;
+	rh.l.low = c = (c >> 32) + rm.l.high + rn.l.high + rh.l.low;
+	rh.l.high = (c >> 32) + rh.l.high;
+
+	/*
+	 * The 128-bit result of the multiplication is in rl.ll and rh.ll,
+	 * shift it right and throw away the high part of the result.
+	 */
+	if (shift == 0)
+		return rl.ll;
+	if (shift < 64)
+		return (rl.ll >> shift) | (rh.ll << (64 - shift));
+	return rh.ll >> (shift & 63);
+}
+#endif /* mul_u64_u64_shr */
+
 #endif
 
 #endif /* _LINUX_MATH64_H */