| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- """
- Deterministic linear-time selection (median-of-medians) example.
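- Provides a single function `select_kth(arr, k)` that returns the k-th smallest element (0-based) of an unsorted list.
- Includes self-checks (randomized arrays and edge cases) that are expected to pass under CPython.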
- """
- import random
def select_kth(arr, k):
    """Return the k-th smallest element of ``arr`` (0-based rank).

    Implements deterministic selection with the median-of-medians pivot
    rule: the input is split into groups of five, the median of each group
    is taken, and the *true* median of those group medians (found by a
    recursive selection) is used as the pivot. That pivot is guaranteed to
    discard a constant fraction of the elements on every round, which is
    what gives the worst-case linear running time.

    Args:
        arr: iterable of mutually comparable items. It is copied into a
            list up front, so the caller's data is never modified and
            one-shot iterables (generators, iterators) are accepted.
        k: int, 0 <= k < number of items in ``arr``.

    Returns:
        The k-th smallest element.

    Raises:
        ValueError: if ``arr`` is empty or ``k`` is out of range.
    """
    # Materialise once: validation needs len(), and the algorithm needs
    # slicing, neither of which works on a bare iterator.
    items = list(arr)
    if not items:
        raise ValueError('empty array')
    if k < 0 or k >= len(items):
        raise ValueError('k out of range')

    def partition(values, pivot):
        """Split ``values`` into (< pivot, == pivot, > pivot) lists."""
        lows = []
        pivots = []
        highs = []
        for x in values:
            if x < pivot:
                lows.append(x)
            elif x > pivot:
                highs.append(x)
            else:
                pivots.append(x)
        return lows, pivots, highs

    def choose_pivot(values):
        """Return the median of the medians of groups of five."""
        medians = []
        for start in range(0, len(values), 5):
            group = sorted(values[start:start + 5])
            medians.append(group[len(group) // 2])
        # The exact median of the group medians is required for the
        # linear-time bound; an approximate "median of medians of
        # medians" only guarantees a sub-linear number of discarded items.
        return select(medians, len(medians) // 2)

    def select(values, rank):
        """Return the element of 0-based ``rank`` in sorted ``values``."""
        # Narrow the candidate list iteratively; only the pivot search
        # recurses, and it does so on a list one fifth of the size.
        while len(values) > 10:
            pivot = choose_pivot(values)
            lows, pivots, highs = partition(values, pivot)
            if rank < len(lows):
                values = lows
            elif rank < len(lows) + len(pivots):
                return pivots[0]
            else:
                rank -= len(lows) + len(pivots)
                values = highs
        # Small input: sorting is cheaper than another partition round.
        return sorted(values)[rank]

    return select(items, k)
def _test_small_cases():
    """Check select_kth against a handful of tiny, hand-verified inputs."""
    # Each entry is (input list, 0-based rank, expected element).
    cases = (
        ([1], 0, 1),
        ([2, 1], 0, 1),
        ([2, 1], 1, 2),
        ([3, 1, 2], 1, 2),
        ([5, 4, 3, 2, 1], 0, 1),
        ([5, 4, 3, 2, 1], 4, 5),
    )
    for values, rank, expected in cases:
        assert select_kth(values, rank) == expected
def _test_randomized(n=100, trials=50):
    """Compare select_kth with a sort-based reference on random inputs.

    Args:
        n: maximum length of each generated list (minimum is 1).
        trials: number of random lists to try.

    Raises:
        AssertionError: on the first mismatch, after printing the
            offending input, rank, expected value and actual value.
    """
    for _trial in range(trials):
        size = random.randint(1, n)
        sample = [random.randint(-1000, 1000) for _ in range(size)]
        rank = random.randrange(size)
        # Reference answer: element at position `rank` of the sorted list.
        reference = sorted(sample)[rank]
        actual = select_kth(sample, rank)
        if actual == reference:
            continue
        print('FAILED on', sample, 'k=', rank, 'expected=', reference, 'got=', actual)
        raise AssertionError('select_kth mismatch')
def _test_edge_cases():
    """Check that an all-duplicates list works at the lowest and highest rank."""
    repeated = [5] * 10
    for rank in (0, 9):
        assert select_kth(repeated, rank) == 5
if __name__ == '__main__':
    # Script entry point: run every self-check and report the outcome.
    print('Running select_kth self-tests...')
    for deterministic_check in (_test_small_cases, _test_edge_cases):
        deterministic_check()
    # Randomized comparison against sorted(), with a larger budget than
    # the function's defaults.
    _test_randomized(n=200, trials=100)
    print('All tests passed for select_kth')
|