Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit db1f428

Browse files
committed
Complete lab2
1 parent 25099a6 commit db1f428

File tree

9 files changed

+212
-10
lines changed

9 files changed

+212
-10
lines changed

lab2/bfprt_select.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
def median(arr: list) -> int:
2+
"""返回 arr 的中位数, arr 的长度不超过 5"""
3+
assert len(arr) <= 5
4+
arr.sort()
5+
return arr[len(arr) // 2]
6+
7+
8+
def partition(arr: list, MoM: int) -> tuple:
9+
"""根据 MoM 划分 arr 为三部分:L, E, G"""
10+
L, E, G = [], [], [] # less, equal, greater
11+
for num in arr:
12+
if num < MoM:
13+
L.append(num)
14+
elif num == MoM:
15+
E.append(num)
16+
else:
17+
G.append(num)
18+
return L, E, G
19+
20+
21+
def bfprt(arr: list) -> int:
22+
"""返回 arr 的中位数的中位数"""
23+
n = len(arr)
24+
if n <= 5:
25+
return median(arr)
26+
m = n // 5
27+
groups = [arr[i * 5:(i + 1) * 5] for i in range(m)]
28+
medians = [median(group) for group in groups]
29+
return bfprt(medians)
30+
31+
32+
def bfprt_select(arr: list, k: int) -> int:
33+
"""返回 arr 中第 k 小的元素"""
34+
# 1. 将 arr 划分为 n//5 组,每组 5 个元素
35+
# 2. 对每个组进行排序,找到其中位数
36+
# 3. 递归地调用 bfprt_select,找到这些中位数的中位数 MoM
37+
# 4. 以 MoM 为基准,划分 arr 为三部分:L, E, G
38+
# 5. 根据 k 与 L, E, G 的大小关系,递归地调用 bfprt_select
39+
# 6. 返回结果
40+
41+
if len(arr) <= 5:
42+
return sorted(arr)[k] # 直接返回第 k 小的元素
43+
MoM = bfprt(arr)
44+
L, E, G = partition(arr, MoM)
45+
if k < len(L):
46+
return bfprt_select(L, k)
47+
elif k < len(L) + len(E):
48+
return E[0]
49+
else:
50+
return bfprt_select(G, k - len(L) - len(E))

lab2/fig/normal.png

162 KB
Loading

lab2/fig/theta.png

154 KB
Loading

lab2/fig/uniform.png

154 KB
Loading

lab2/fig/zipf.png

230 KB
Loading

lab2/gen_data.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import numpy as np
2+
3+
4+
def gen_data(data_type: str, n: int, iter_num: int) -> (list, list):
5+
k_list = [np.random.randint(0, n) for _ in range(iter_num)]
6+
if data_type == "uniform":
7+
return np.random.uniform(0, 1, n).tolist(), k_list
8+
elif data_type == "normal":
9+
return np.random.normal(0, 1, n).tolist(), k_list
10+
elif data_type == "zipf":
11+
return np.random.zipf(2, n).tolist(), k_list

lab2/lazy_select.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import random
2+
from math import sqrt, floor
3+
4+
5+
def rank(arr: list, x: int) -> int:
6+
"""返回 arr 中小于 x 的元素个数"""
7+
return sum(1 for num in arr if num < x)
8+
9+
10+
def min_k(sorted_arr: list, k: int) -> int:
11+
"""返回有序数组 sorted_arr 中第 k 小的元素"""
12+
return sorted_arr[k]
13+
14+
15+
def lazy_select(arr: list, k: int, theta: float = 3 / 4) -> int:
16+
"""拉斯维加斯算法,返回 arr 中第 k 小的元素"""
17+
n = len(arr)
18+
R_len = int(n ** theta)
19+
R = random.choices(arr, k=R_len) # 随机选择 n^(3/4) 个元素
20+
R.sort() # 此排序的时间复杂度为 O(n)
21+
x = (k / n) * R_len # arr 的第 k 小元素可能成为 R 的第 x 小元素
22+
l, h = max(floor(x - sqrt(n)), 0), min(floor(x + sqrt(n)), R_len - 1) # 考察区间 [l, h]
23+
L, H = min_k(R, l), min_k(R, h)
24+
Lp, Hp = rank(arr, L), rank(arr, H)
25+
P = [num for num in arr if L <= num <= H] # 将 arr 中介于 L, H 之间的元素放入 P
26+
27+
if Lp <= k <= Hp and len(P) <= 4 * n ** theta + 1:
28+
P.sort()
29+
return min_k(P, k - Lp)
30+
else:
31+
if R_len < n:
32+
return lazy_select(arr, k, min(theta + 0.05, 1)) # 略微提高 R_len 的大小
33+
else:
34+
arr.sort()
35+
return min_k(arr, k)

lab2/main.py

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,107 @@
1-
# This is a sample Python script.
1+
"""
2+
比较 3 种中位数选择算法的性能
3+
- 算法 1:排序后选择
4+
- 算法 2: 确定型中位数线性时间选择 (BFPRT)
5+
- 算法 3: 中位数选择随机算法
26
3-
# Press Shift+F10 to execute it or replace it with your code.
4-
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
7+
实验内容:
8+
- 实现三种算法
9+
- 数据集自己寻找或生成
10+
- 运行时间比较,准确度比较
11+
- 扩展性比较
12+
- 以恰当、准确、规范的形式表述实验结果
13+
"""
514

15+
import time
16+
import numpy as np
17+
import matplotlib.pyplot as plt
618

7-
def print_hi(name):
8-
# Use a breakpoint in the code line below to debug your script.
9-
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
19+
from sort_select import sort_select
20+
from bfprt_select import bfprt_select
21+
from lazy_select import lazy_select
22+
from gen_data import gen_data
1023

1124

12-
# Press the green button in the gutter to run the script.
13-
if __name__ == '__main__':
14-
print_hi('PyCharm')
25+
def run_select(arr: list, k_list: list, func) -> list:
26+
"""测试选择算法, 返回运行结果"""
27+
result = []
28+
for k in k_list:
29+
result.append(func(arr, k))
30+
return result
1531

16-
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
32+
33+
def test_all_on_data(arr: list, k_list: list):
34+
run_time = []
35+
36+
start_time = time.time()
37+
sort_select_result = run_select(arr, k_list, sort_select)
38+
run_time.append(time.time() - start_time)
39+
40+
start_time = time.time()
41+
bfprt_select_result = run_select(arr, k_list, bfprt_select)
42+
run_time.append(time.time() - start_time)
43+
44+
start_time = time.time()
45+
lazy_select_result = run_select(arr, k_list, lazy_select)
46+
run_time.append(time.time() - start_time)
47+
48+
if (sort_select_result != bfprt_select_result) or (sort_select_result != lazy_select_result):
49+
print("Results are not equal!")
50+
51+
return run_time
52+
53+
54+
def test(data_type: str, n_list: list, iter_num: int):
55+
run_times = [[] for _ in range(3)] # [[sort_select], [bfprt_select], [lazy_select
56+
for n in n_list:
57+
arr, k_list = gen_data(data_type, n, iter_num)
58+
run_time = test_all_on_data(arr, k_list)
59+
for i in range(3):
60+
run_times[i].append(run_time[i] / iter_num)
61+
62+
fig = plt.figure(dpi=400)
63+
ax = fig.add_subplot(111)
64+
ax.plot(n_list, run_times[0], label="sort_select")
65+
ax.plot(n_list, run_times[1], label="bfprt_select")
66+
ax.plot(n_list, run_times[2], label="lazy_select")
67+
ax.set_xlabel("Data Size")
68+
ax.set_ylabel("Run Time")
69+
ax.set_title(("Run Time of Three Select Algorithms on " + data_type + " Data").title())
70+
ax.legend()
71+
plt.show()
72+
73+
74+
def test_theta(n: int, iter_num: int):
75+
theta_list = np.linspace(0.5, 1, 100).tolist()
76+
run_times = []
77+
for theta in theta_list:
78+
arr, k_list = gen_data("uniform", n, iter_num)
79+
start_time = time.time()
80+
for k in k_list:
81+
lazy_select(arr, k, theta)
82+
run_times.append((time.time() - start_time) / iter_num)
83+
84+
fig = plt.figure(dpi=400)
85+
ax = fig.add_subplot(111)
86+
ax.plot(theta_list, run_times)
87+
ax.set_xlabel("Theta")
88+
ax.set_ylabel("Run Time")
89+
ax.set_title("Run Time of Lazy Select Algorithm on Different Theta")
90+
plt.show()
91+
92+
93+
def main():
94+
# 测试 3 种算法的性能和扩展性
95+
iter_num = 3 # 测试次数
96+
n_list = np.linspace(10000, 100000, 20, dtype=int).tolist() # 数据规模
97+
data_type_list = ["uniform", "normal", "zipf"]
98+
for data_type in data_type_list:
99+
test(data_type, n_list, iter_num)
100+
101+
# 测试随机算法中的关键参数 theta 对性能的影响
102+
n = 10000
103+
test_theta(n, iter_num)
104+
105+
106+
if __name__ == "__main__":
107+
main()

lab2/sort_select.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def quick_sort(arr: list) -> list:
2+
"""对 arr 进行快速排序"""
3+
if len(arr) <= 1:
4+
return arr
5+
pivot = arr[len(arr) // 2]
6+
left = [x for x in arr if x < pivot]
7+
middle = [x for x in arr if x == pivot]
8+
right = [x for x in arr if x > pivot]
9+
return quick_sort(left) + middle + quick_sort(right)
10+
11+
12+
def sort_select(arr: list, k: int) -> int:
13+
"""将 arr 排序后,返回其中第 k 小的元素"""
14+
arr = quick_sort(arr)
15+
return arr[k]

0 commit comments

Comments
 (0)