剑指 Offer 51：数组中的逆序对

 1 # -*- coding:utf-8 -*-
 2 class Solution:
 3     def __init__(self):
 4         self.cnt = 0
 5         self.tmp = []
 6         
 7     def InversePairs(self, data):
 8         n = len(data)
 9         self.tmp =[0] * n
10         self.mergeSort(data,0,n-1)
11         return self.cnt % 1000000007
12         
13     def mergeSort(self,nums,l,h):
14         if h - l < 1:
15             return
16         m = l + (h - l) // 2
17         self.mergeSort(nums,l,m)
18         self.mergeSort(nums,m+1,h)
19         self.merge(nums,l,m,h)
20         
21     def merge(self,nums,l,m,h):
22         i,j,k = l,m+1,l
23         while i <= m or j <= h:
24             if i > m:
25                 self.tmp[k] = nums[j]
26                 j += 1
27             elif j > h:
28                 self.tmp[k] = nums[i]
29                 i += 1
30             elif nums[i] <= nums[j]:
31                 self.tmp[k] = nums[i]
32                 i += 1
33             else:
34                 self.tmp[k] = nums[j]
35                 j += 1
36                 self.cnt += m - i + 1
37             k += 1
38         k = l
39         while k <= h:
40             nums[k] = self.tmp[k]
41             k += 1
42         # write code here

上面的 Python 实现提交后会超时。据说是 OJ 对 Python 的时间判定有问题；另外，纯 Python 实现的归并排序本身常数开销也较大。

下面是可以通过提交的 Java 参考实现：

 1 public class Solution {
 2     private long cnt = 0;
 3     private int[] tmp;  // 在这里声明辅助数组,而不是在 merge() 递归函数中声明
 4 
 5     public int InversePairs(int[] nums) {
 6         tmp = new int[nums.length];
 7         mergeSort(nums, 0, nums.length - 1);
 8         return (int) (cnt % 1000000007);
 9     }
10 
11     private void mergeSort(int[] nums, int l, int h) {
12         if (h - l < 1)
13             return;
14         int m = l + (h - l) / 2;
15         mergeSort(nums, l, m);
16         mergeSort(nums, m + 1, h);
17         merge(nums, l, m, h);
18     }
19 
20     private void merge(int[] nums, int l, int m, int h) {
21         int i = l, j = m + 1, k = l;
22         while (i <= m || j <= h) {
23             if (i > m)
24                 tmp[k] = nums[j++];
25             else if (j > h)
26                 tmp[k] = nums[i++];
27             else if (nums[i] <= nums[j])
28                 tmp[k] = nums[i++];
29             else {
30                 tmp[k] = nums[j++];
31                 this.cnt += m - i + 1;  // nums[i] > nums[j],说明 nums[i...mid] 都大于 nums[j]
32         }
33         k++;
34     }
35     for (k = l; k <= h; k++)
36         nums[k] = tmp[k];
37     }
38 }

猜你喜欢

转载自：https://www.cnblogs.com/asenyang/p/11023306.html