select_kth.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. """
  2. Deterministic linear-time selection (median-of-medians) example.
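  3. Provides a single function `select_kth(arr, k)` that returns the k-th smallest element (0-based) of an unsorted list.
  4. Includes self-tests (randomized arrays and edge cases) to confirm that it runs correctly under CPython.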
  5. """
  6. import random
  7. def select_kth(arr, k):
  8. """Return the k-th smallest element of arr (0-based index) using
  9. the median-of-medians algorithm to guarantee linear time.
  10. Args:
  11. arr: list of comparable items
  12. k: int, 0 <= k < len(arr)
  13. Returns:
  14. The k-th smallest element.
  15. Raises:
  16. ValueError: if k is out of range or arr is empty.
  17. """
  18. if not arr:
  19. raise ValueError('empty array')
  20. if k < 0 or k >= len(arr):
  21. raise ValueError('k out of range')
  22. def partition(a, pivot):
  23. lows = []
  24. highs = []
  25. pivots = []
  26. for x in a:
  27. if x < pivot:
  28. lows.append(x)
  29. elif x > pivot:
  30. highs.append(x)
  31. else:
  32. pivots.append(x)
  33. return lows, pivots, highs
  34. def median_of_medians(a):
  35. # divide a into groups of 5
  36. groups = [a[i:i+5] for i in range(0, len(a), 5)]
  37. medians = []
  38. for g in groups:
  39. g.sort()
  40. medians.append(g[len(g)//2])
  41. if len(medians) <= 5:
  42. medians.sort()
  43. return medians[len(medians)//2]
  44. # recursively find median of medians
  45. return select_median(medians)
  46. def select_median(a):
  47. # helper to select median index
  48. return median_of_medians(a)
  49. def select(a, k):
  50. if len(a) <= 10:
  51. # small array: sort and return
  52. b = sorted(a)
  53. return b[k]
  54. pivot = median_of_medians(a)
  55. lows, pivots, highs = partition(a, pivot)
  56. if k < len(lows):
  57. return select(lows, k)
  58. elif k < len(lows) + len(pivots):
  59. return pivots[0]
  60. else:
  61. return select(highs, k - len(lows) - len(pivots))
  62. return select(list(arr), k)
  63. def _test_small_cases():
  64. # basic sanity checks
  65. assert select_kth([1], 0) == 1
  66. assert select_kth([2, 1], 0) == 1
  67. assert select_kth([2, 1], 1) == 2
  68. assert select_kth([3,1,2], 1) == 2
  69. assert select_kth([5,4,3,2,1], 0) == 1
  70. assert select_kth([5,4,3,2,1], 4) == 5
  71. def _test_randomized(n=100, trials=50):
  72. for _ in range(trials):
  73. length = random.randint(1, n)
  74. arr = [random.randint(-1000, 1000) for _ in range(length)]
  75. k = random.randrange(length)
  76. expected = sorted(arr)[k]
  77. got = select_kth(arr, k)
  78. if got != expected:
  79. print('FAILED on', arr, 'k=', k, 'expected=', expected, 'got=', got)
  80. raise AssertionError('select_kth mismatch')
  81. def _test_edge_cases():
  82. # duplicates
  83. arr = [5] * 10
  84. assert select_kth(arr, 0) == 5
  85. assert select_kth(arr, 9) == 5
  86. if __name__ == '__main__':
  87. print('Running select_kth self-tests...')
  88. _test_small_cases()
  89. _test_edge_cases()
  90. _test_randomized(n=200, trials=100)
  91. print('All tests passed for select_kth')