[LeetCode刷题笔记]4 - 寻找两个正序数组的中位数(归并 / 递归 / 二分查找)
一、题目描述
- 给定两个大小分别为 m 和 n 的正序(排好序了,从小到大)数组 nums1 和 nums2。请你找出并返回其 中位数 (这里指的是合并成一个数组后的中位数)。
- 要求所用的算法时间复杂度应该在 O ( l o g ( m + n ) ) O(log (m+n)) O(log(m+n)) 以内。
示例:
输入 | 输出 |
---|---|
n u m s 1 = [ 1 , 3 ] , n u m s 2 = [ 2 ] nums1 = [1,3], nums2 = [2] nums1=[1,3],nums2=[2] | 2.00000 2.00000 2.00000 |
n u m s 1 = [ 1 , 2 ] , n u m s 2 = [ 3 , 4 ] nums1 = [1,2], nums2 = [3,4] nums1=[1,2],nums2=[3,4] | 2.50000 2.50000 2.50000 |
提示:
- n u m s 1. l e n g t h = m nums1.length = m nums1.length=m
- n u m s 2. l e n g t h = n nums2.length = n nums2.length=n
- 0 < = m < = 1000 0 <= m <= 1000 0<=m<=1000
- 0 < = n < = 1000 0 <= n <= 1000 0<=n<=1000
- 1 < = m + n < = 2000 1 <= m + n <= 2000 1<=m+n<=2000
- − 1 0 6 < = n u m s 1 [ i ] , n u m s 2 [ i ] < = 1 0 6 -10^6 <= nums1[i], nums2[i] <= 10^6 −106<=nums1[i],nums2[i]<=106
二、求解思路
思路一:归并( O ( m + n ) O(m+n) O(m+n))
- 使用归并的方式,合并两个有序数组,得到一个大的有序数组。大的有序数组的中间位置的元素(这里需要根据元素个数做个判断),即为中位数。
- 此时时间复杂度为 O ( m + n ) O(m+n) O(m+n),肯定不符合题意,需要进行改进。
C++代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size(), n = nums2.size();
vector<int> mg; // merge后的有序大数组
int i = 0, j = 0; // 双指针归并算法
while (i < m && j < n) {
// 把小的优先放入mg
if (nums1[i] < nums2[j]) mg.push_back(nums1[i++]);
else mg.push_back(nums2[j++]);
}
while (i < m) mg.push_back(nums1[i++]); // 把nums1剩余元素直接依次放入mg
while (j < n) mg.push_back(nums2[j++]); // 把nums2剩余元素直接依次放入mg
// 找到中位数
int k = mg.size();
double res;
if (k % 2) res = mg[(k - 1) >> 1];
else res = 1.0*(mg[k/2 - 1] + mg[k/2])/2;
return res;
}
};
复杂度分析
- 时间复杂度: O ( m + n ) O(m+n) O(m+n)。
- 空间复杂度: O ( m + n ) O(m+n) O(m+n)。
思路二:递归( O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)))
原问题难以直接递归求解,所以我们先考虑这样一个问题:
在两个有序数组中,找出第 k k k 小的数。
如果该问题可以解决,那么第 ( n + m ) / 2 (n+m)/2 (n+m)/2 小的数就是我们要求的中位数。
先从简单情况入手,假设 m , n ≥ k / 2 m,n≥k/2 m,n≥k/2,我们先从 n u m s 1 nums1 nums1 和 n u m s 2 nums2 nums2 中各取前 k / 2 k/2 k/2 个元素:
- 如果 n u m s 1 [ k / 2 − 1 ] > n u m s 2 [ k / 2 − 1 ] nums1\left[k/2−1\right]>nums2\left[k/2−1\right] nums1[k/2−1]>nums2[k/2−1],则说明 n u m s 1 nums1 nums1 中取的元素过多, n u m s 2 nums2 nums2 中取的元素过少;因此 n u m s 2 nums2 nums2 中的前 k / 2 k/2 k/2 个元素一定都小于等于第 k k k 小数,所以我们可以先取出这些数,将问题归约成在剩下的数中找第 k − ⌊ k / 2 ⌋ k−⌊k/2⌋ k−⌊k/2⌋ 小数。
- 如果 n u m s 1 [ k / 2 − 1 ] ≤ n u m s 2 [ k / 2 − 1 ] nums1\left[k/2−1\right]≤nums2\left[k/2−1\right] nums1[k/2−1]≤nums2[k/2−1],同理可说明 n u m s 1 nums1 nums1 中的前 k / 2 k/2 k/2 个元素一定都小于等于第 k k k 小数,类似可将问题的规模减少一半。
现在考虑边界情况,如果 m < k / 2 m<k/2 m<k/2,则我们从 n u m s 1 nums1 nums1 中取 m m m 个元素,从 n u m s 2 nums2 nums2 中取 k / 2 k/2 k/2 个元素(由于 k = ( n + m ) / 2 k=(n+m)/2 k=(n+m)/2,因此 m , n m,n m,n 不可能同时小于 k / 2 k/2 k/2:
- 如果 n u m s 1 [ m − 1 ] > n u m s 2 [ k / 2 − 1 ] nums1[m−1]>nums2[k/2−1] nums1[m−1]>nums2[k/2−1],则 n u m s 2 nums2 nums2 中的前 k / 2 k/2 k/2 个元素一定都小于等于第 k k k 小数,我们可以将问题归约成在剩下的数中找第 k − ⌊ k / 2 ⌋ k−⌊k/2⌋ k−⌊k/2⌋ 小数。
- 如果 n u m s 1 [ m − 1 ] ≤ n u m s 2 [ k / 2 − 1 ] nums1[m−1]≤nums2[k/2−1] nums1[m−1]≤nums2[k/2−1],则 n u m s 1 nums1 nums1 中的所有元素一定都小于等于第 k k k 小数,因此第 k k k小数是 n u m s 2 [ k − m − 1 ] nums2[k−m−1] nums2[k−m−1]。
C++代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
// 计算长度总和,从而获得中位数的位置
int total = nums1.size() + nums2.size();
if (total % 2 == 0) {
// 如果total为偶数,则中位数是中间两个数的均值,位置分别在total/2和total/2+1
// 注意:这里的位置均是从1开始的,比如说6个数,下标从1-6,中位数位置为第3和4个数的均值
// 0 表示从数组的哪个下标开始搜索
int left = findKthNumber(nums1, 0, nums2, 0, total / 2);
int right = findKthNumber(nums1, 0, nums2, 0, total / 2 + 1);
return (left + right) / 2.0;
}
else {
// 如果total为奇数,则中位数是位置为第total/2+1的数(这里是整数向下除法)
return findKthNumber(nums1, 0, nums2, 0, total / 2 + 1);
}
}
// 该函数作用:下标分别从i和j开始,搜索两个数组第k小的数(这里的第k小是指下标从1开始的第k小)
int findKthNumber(vector<int> &nums1, int i, vector<int> &nums2, int j, int k) {
// 这里默认nums1的长度比nums2的长度短,否则就调换一下
if (nums1.size() - i > nums2.size() - j) return findKthNumber(nums2, j, nums1, i, k);
// 边界情况:nums1为空,则直接返回nums2的第j+k-1位置上的数
if (nums1.size() == i) return nums2[j + k - 1];
// 边界情况:k==1,即找最小值即可。注意:此时nums1和nums2一定都不为空
if (k == 1) return min(nums1[i], nums2[j]);
// 一般情况:先确定两个数组第k/2位置在数组中的下标是多少(si和sj)
// 注意:由于nums1比较短,因此i+k/2可能会越界,需要和nums1.size()取个最小值
int si = min(i + k / 2, int(nums1.size())), sj = j + k / 2;
if (nums1[si - 1] > nums2[sj - 1]) {
// 如果nums1第si位置上的数大于nums2第sj位置上的数,则可以删除nums2前sj个数
// 因此,下一次迭代nums2从sj开始,且搜寻第k-k/2小的数
return findKthNumber(nums1, i, nums2, sj, k - k / 2);
} else {
// 如果nums1第si位置上的数小于nums2第sj位置上的数,则可以删除nums1前si个数
// 因此,下一次迭代nums1从si开始,且搜寻第k-(si-i)小的数
return findKthNumber(nums1, si, nums2, j, k - (si - i));
}
}
};
复杂度分析
- 时间复杂度: O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n))。 k = ( m + n ) / 2 k=(m+n)/2 k=(m+n)/2,且每次递归 k k k 的规模都减少一半,因此时间复杂度是 O ( l o g ( m + n ) ) O(log(m+n)) O(log(m+n)).
- 空间复杂度: O ( 1 ) O(1) O(1)。
思路三:二分查找( O ( l o g ( m i n ( m + n ) ) ) O(log(min(m+n))) O(log(min(m+n))))
『该算法的处理细节非常繁琐,建议新手直接跳过』
首先,让我们考虑只有一个有序数组的情况:
- 如果我们将有序数组切分为等长的左右两部分,则 中位数 = (左半边的最大值 + 右半边的最小值) / 2。
- 切分情况有两种,例如:
- 数组长度是偶数,对于 [ 2 , 3 , 5 , 7 ] [2,3,5,7] [2,3,5,7], 我们在 3 3 3 和 5 5 5 之间切分: [ 2 , 3 ∣ 5 , 7 ] [2,3| 5,7] [2,3∣5,7],则 中位数 =(3+5)/2;
- 数组长度是奇数,对于 [ 2 , 3 , 4 , 5 , 7 ] [2,3,4,5,7] [2,3,4,5,7],我们在 4 4 4 的位置切分,且让 4 4 4 属于左右两边: [ 2 , 3 , ( 4 ∣ 4 ) , 5 , 7 ] [2,3,(4 | 4),5,7] [2,3,(4∣4),5,7],则 中位数 =(4+4)/2=(4+4)/2;
现在让我们来考虑两个有序数组的情况,类似于上述考虑方式:
- 我们在两个数组中分别找到一个分割点 (分割点可能在相邻数之间,也可能在数上)。两个分割点左边元素的总个数等于右边元素的总个数,且左边元素的最大值 <= 右边元素的最小值,则该分割点即为所求。
为了同时处理分割点『在两数之间』和『在数上』的情况,我们在数组 A 1 A_1 A1中可能是分割点的位置添加虚拟元素 ‘ @ ’ ‘@’ ‘@’,这样我们枚举数组 A 1 ′ A'_1 A1′ 的所有元素,即可枚举 A 1 A_1 A1 所有可能的分割点:
- A 1 : [ 1 , 2 , 3 , 4 , 5 ] = > A 1 ′ : [ @ , 1 , @ , 2 , @ , 3 , @ , 4 , @ , 5 , @ ] A_1:[1,2,3,4,5]=>A'_1:[@,1,@,2,@,3,@,4,@,5,@] A1:[1,2,3,4,5]=>A1′:[@,1,@,2,@,3,@,4,@,5,@]
- A 2 : [ 1 , 1 , 1 , 1 ] = > A 2 ′ : [ @ , 1 , @ , 1 , @ , 1 , @ , 1 , @ ] A_2:[1,1,1,1]=>A'_2:[@,1,@,1,@,1,@,1,@] A2:[1,1,1,1]=>A2′:[@,1,@,1,@,1,@,1,@]
我们将数组 A 1 A_1 A1 的长度记为 N 1 N_1 N1,则 A 1 ′ A'_1 A1′ 的长度为 2 ∗ N 1 + 1 2∗N_1+1 2∗N1+1; A 2 A_2 A2 的长度记为 N 2 N_2 N2,则 A 2 ′ A'_2 A2′ 的长度为 2 ∗ N 2 + 1 2∗N_2+1 2∗N2+1。因此,总共有 2 N 1 + 2 N 2 + 2 2N_1+2N_2+2 2N1+2N2+2 个元素。
假设数组 A 1 ′ A'_1 A1′ 的分割点下标是 C 1 C_1 C1,数组 A 2 ′ A'_2 A2′ 的分割点下标是 C 2 C_2 C2,则 C 1 C_1 C1 和 C 2 C_2 C2 之间具有如下等式关系:
- C 1 + C 2 = N 1 + N 2 C_1+C_2=N_1+N_2 C1+C2=N1+N2
- C 1 + C 2 = N 1 + N 2 C_1+C_2=N_1+N_2 C1+C2=N1+N2
证明:除了 C 1 C_1 C1 和 C 2 C_2 C2 以外,共有 2 N 1 + 2 N 2 2N_1+2N_2 2N1+2N2 个元素,要平均分配到左右两边,因此左边共有 N 1 + N 2 N_1+N_2 N1+N2 个元素。 数组下标从 0 0 0 开始,下标为 C 1 C_1 C1 的元素左边有 C 1 C_1 C1 个元素,下标为 C 2 C_2 C2 的元素左边有 C 2 C_2 C2 个元素,由此可得上述等式。
为了方便表述,在 A 1 ′ A'_1 A1′ 中, C 1 C_1 C1 左边(包括 C 1 C_1 C1)的最大值记为 L 1 L_1 L1, C 1 C_1 C1 右边(包括 C 1 C_1 C1)的最小值记为 R 1 R_1 R1;在 A 2 ′ A'_2 A2′ 中, C 2 C_2 C2 左边(包括 C 2 C_2 C2)的最大值记为 L 2 L_2 L2, C 2 C_2 C2 右边(包括 C 2 C_2 C2)的最小值记为 R 2 R_2 R2。
- L 1 < = R 1 & & L 1 < = R 2 & & L 2 < = R 1 & & L 2 < = R 2 L_1<=R_1\ \&\&\ L_1<=R_2\ \&\&\ L_2<=R_1\ \&\&\ L_2<=R_2 L1<=R1 && L1<=R2 && L2<=R1 && L2<=R2
由于 A 1 , A 2 A_1,A_2 A1,A2 都是有序的,因此 L 1 < = R 1 & & L 2 < = R 2 L_1<=R_1\ \&\&\ L_2<=R_2 L1<=R1 && L2<=R2 一定满足;不满足的情况有两种:
- 如果 L 1 > R 2 L_1>R_2 L1>R2,表示 A 2 A_2 A2中在分割点左侧的元素太少,此时我们需要将 C 2 C_2 C2 右移;
- 如果 L 2 > R 1 L_2>R_1 L2>R1,表示 A 2 A_2 A2中在分割点左侧的元素太多,此时我们需要将 C 2 C_2 C2 左移;
符合二分结构。
另外,我们在实际操作中,不需要真的在原数组中插入 ‘ @ ’ ‘@’ ‘@’,只需找出 L 1 , R 1 , L 2 , R 2 L_1,R_1,L_2,R_2 L1,R1,L2,R2 和 C 1 , C 2 C_1,C_2 C1,C2 的关系即可。我们可以列表找规律:
C 1 C_1 C1 | L 1 L_1 L1 | R 1 R_1 R1 |
---|---|---|
0 | INT_MIN | A1[0] |
1 | A1[0] | A1[0] |
2 | A1[0] | A1[1] |
3 | A1[1] | A1[1] |
4 | A1[1] | A1[2] |
由此我们可以发现:
- L 1 = A 1 [ ( C 1 − 1 ) / 2 ] L_1=A_1[(C_1−1)/2] L1=A1[(C1−1)/2]
- R 1 = A 1 [ C 1 / 2 ] R_1=A_1[C_1/2] R1=A1[C1/2]
类似可得:
- L 2 = A 2 [ ( C 2 − 1 ) / 2 ] L_2=A_2[(C_2−1)/2] L2=A2[(C2−1)/2]
- R 2 = A 2 [ C 2 / 2 ] R_2=A_2[C_2/2] R2=A2[C2/2]
最后,还有两点需要注意:
- 我们只能二分长度较小的数组,因为长度较长的数组中的某些分割点可能不合法,会出现 C 1 > N 1 + N 2 C_1>N_1+N_2 C1>N1+N2 的情况;
- 我们在数组边界设置两个哨兵,来处理
C
1
=
0
C_1=0
C1=0 或
C
1
=
2
N
1
C_1=2N_1
C1=2N1 的情况,这样做并不会影响结果,但可以简化代码:
- A 1 [ − 1 ] = I N T M I N , A 1 [ 2 N 1 ] = I N T M A X A_1[−1]=INTMIN,A_1[2N_1]=INTMAX A1[−1]=INTMIN,A1[2N1]=INTMAX
C++代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int N1 = nums1.size();
int N2 = nums2.size();
if (N1 < N2) return findMedianSortedArrays(nums2, nums1);
int lo = 0, hi = N2 * 2;
while (lo <= hi) {
int mid2 = (lo + hi) / 2;
int mid1 = N1 + N2 - mid2;
double L1 = (mid1 == 0) ? INT_MIN : nums1[(mid1-1)/2];
double L2 = (mid2 == 0) ? INT_MIN : nums2[(mid2-1)/2];
double R1 = (mid1 == N1 * 2) ? INT_MAX : nums1[(mid1)/2];
double R2 = (mid2 == N2 * 2) ? INT_MAX : nums2[(mid2)/2];
if (L1 > R2) lo = mid2 + 1;
else if (L2 > R1) hi = mid2 - 1;
else return (max(L1,L2) + min(R1, R2)) / 2;
}
return -1;
}
};
复杂度分析
- 时间复杂度: O ( l o g ( m i n ( n , m ) ) ) O(log(min(n,m))) O(log(min(n,m)))。二分长度较短的数组,且每次二分的复杂度是 O(1)O(1),所以总复杂度是 O ( l o g ( m i n ( n , m ) ) ) O(log(min(n,m))) O(log(min(n,m)))。
- 空间复杂度: O ( 1 ) O(1) O(1)。