简单的分治排序算法

我们介绍两个基于分治且 O(nlogn) 的排序算法: 快速排序和归并排序. 理解排序算法的最好方法是模拟, 手动模拟指针的移动, 感受数据是怎么逐渐变得有序的. 配合动图食用更佳.

你可能会问: std::sort 它不香吗? 答: 让你手搓是为了理解算法思想, 没让你去用. 顺便提防将来的面试官.

两个算法困难的点在于边界分析. 算法导论就是干这个证明的. 不过正常人记模板就可以了 :)

快速排序

原理

快速排序的核心要义: 任取一个数x作为基准值, 调整x的左右区间, 使得左区间的值小于x, 右区间的值都大于x, 然后对这两个子区间递归调用.

怎么调整区间:

  • 扫描并比较整个数组, 大于 x 的放在数组 a, 小于的放在 b, 最后再复制回去.(空间开销大)
  • 双指针. il 开始移动, 确保 q[l...i-1]x 小. 当遇到比 x 大的值, 就让这个值和 j 交换位置, 然后从后移动 j 指针, 确保 q[j...r]x 大. 遇到比 x 小的值, 就再交换回去, 直到 i, j 相遇.

这两个区间的值也可以等于x. 算法的正确性, 用循环不变式证明: q[l..i] <= x 以及 q[j..r] >= x

性质

朴素的快排是一种不稳定的排序方式.

最优时间复杂度和平均时间复杂度为 \(O(n \log n)\) (基准数是中位数的情况), 最差情况退化成冒泡排序 \(O(n^2)\) (每一次基准数的选取都是最值).

模板

来源: AcWing 785. 快速排序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void quick_sort(int q[], int l, int r)
{
if(l >= r) return;

int i = l - 1, j = r + 1, x = q[l + r >> 1];
while(i < j)
{
do i++; while(q[i] < x);
do j--; while(q[j] > x);
if(i < j) swap(q[i], q[j]);
}

quick_sort(q, l, j), quick_sort(q, j + 1, r);
}

归并排序

归并排序是典范的分治思想的例子. 归并排序的核心要义是: 把数组一分为二, 一直分下去直到不能再分, 然后逐层合并两个有序的数组.

用程序语言描述就是下面的代码:

1
2
3
4
5
6
7
void merge_sort(array) {
if (basecase) return;
merge_sort(left);
merge_sort(right);
merge(left, right);
}

核心是实现子程序 merge: 合并两个排好序的数组. 实现起来很简单, 用到了临时数组, 令 tmp[k] = min(left[i], right[j]) 最后复制回原数组就可以了. 整个过程要用到三个指针, 分别指向临时数组, 以及两个排好序的数组.

板子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void merge_sort(int q[], int l, int r)
{
if(l >= r)
return;
int mid = l + r >> 1;

merge_sort(q, l, mid);
merge_sort(q, mid + 1, r);

int k = 0, i = l, j = mid + 1;
int N = l + r - 1;
int tmp[N];

while (i <= mid && j <= r)
tmp[k++] = (q[i] <= q[j]) ? q[i++] : q[j++];
while(i <= mid)
tmp[k++] = q[i++];
while(j <= r)
tmp[k++] = q[j++];

for(i = 1, j = 0; i <= r; i++, j++)
q[i] = tmp[j];
}

应用

AcWing 788. 逆序对的数量

用归并排序看待这个问题: 逆序对只存在三种情况, 在左半边, 在右半边, 或者横跨左右区间. 假设 merge_sort 能在排序的同时计算出逆序对的数量, 那么前两种情况等于递归调用 merge_sort 的值. 横跨左右区间的情况稍微比较复杂. 利用的是merge区间的单调性. 我们求的是left中每一个元素对应right中逆序对的数量. 假设我们找到了 left[i] > right[j], 那么一定有 left[i...mid] > right[j]. 也就是说存在 mid - i + 1 个逆序对.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
LL merge_sort(int q[], int l, int r)
{
if(l >= r) return 0;
int mid = l + r >> 1;
LL res = merge_sort(q, l, mid) + merge_sort(q, mid + 1, r);

int k = 0, i = l, j = mid + 1;
while(i <= mid && j <= r)
{
if(q[i] <= q[j])
tmp[k++] = q[i++];
else
{
res += mid - i + 1;
tmp[k++] = q[j++];
}
}
while(i <= mid)
tmp[k++] = q[i++];
while(j <= r)
tmp[k++] = q[j++];

for(i = l, j = 0; i <= r; i++, j++)
q[i] = tmp[j];
return res;
}

k-way merge

如果需要merge的数组不止2个, 该怎么实现呢?

  • 建立一个小根堆.
  • 将每一路的第一个元素插入小根堆. 我们可以知道heap[1] 就是最小值.
  • 将堆顶元素弹出, 并将堆顶元素所在数组的下一元素加入堆中.
  • 重复上面两步, 直到每一路都读取结束.

LeetCode 第 23 号问题: 合并 K 个排序链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// struct ListNode
// {
// int val;
// ListNode *next;
// ListNode(int x) : val(x), next(NULL) {}
// };
class Solution
{
public:
struct mycomp
{
bool operator()(ListNode *a, ListNode *b)
{
return a->val > b->val;
}
};
ListNode *mergeKLists(vector<ListNode *> &lists)
{
priority_queue<ListNode *, vector<ListNode *>, mycomp> queue;
for (ListNode *head : lists)
{
if (head)
queue.push(head);
}
ListNode *dummy = new ListNode(-1);
ListNode *temp = dummy;
while (!queue.empty())
{
ListNode *p = queue.top();
queue.pop();
if (p->next)
queue.push(p->next);
temp->next = p;
temp = temp->next;
}
return dummy->next;
}
};

鱼塘钓鱼

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <iostream>
#include <queue>

using namespace std;

typedef pair<int, int> PII;

const int N = 110;
int a[N], d[N], s[N];
priority_queue<PII> q;

int work(int t)
{
int res = 0;
for (int i = 1; i <= t; i++)
{
auto p = q.top();
q.pop();

if (p.first <= 0)
break;
res += p.first;
p.first -= p.second;
q.push(p);
}
return res;
}

int main()
{
int n, T;
cin >> n;

for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= n; i++)
cin >> d[i];
for (int i = 2; i <= n; i++)
{
cin >> s[i];
s[i] += s[i - 1];
}

cin >> T;
int ans = 0;

for (int i = 1; i <= n && s[i] <= T; i++)
{
for (int j = 1; j <= i; j++)
q.push({a[j], d[j]});

ans = max(ans, work(T - s[i]));
}
cout << ans << endl;
return 0;
}