test_refactor.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Unit tests for refactor.py.
  3. """
  4. import sys
  5. import os
  6. import codecs
  7. import io
  8. import re
  9. import tempfile
  10. import shutil
  11. import unittest
  12. from lib2to3 import refactor, pygram, fixer_base
  13. from lib2to3.pgen2 import token
  14. TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
  15. FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers")
  16. sys.path.append(FIXER_DIR)
  17. try:
  18. _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
  19. finally:
  20. sys.path.pop()
  21. _2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes")
  22. class TestRefactoringTool(unittest.TestCase):
  23. def setUp(self):
  24. sys.path.append(FIXER_DIR)
  25. def tearDown(self):
  26. sys.path.pop()
  27. def check_instances(self, instances, classes):
  28. for inst, cls in zip(instances, classes):
  29. if not isinstance(inst, cls):
  30. self.fail("%s are not instances of %s" % instances, classes)
  31. def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None):
  32. return refactor.RefactoringTool(fixers, options, explicit)
  33. def test_print_function_option(self):
  34. rt = self.rt({"print_function" : True})
  35. self.assertIs(rt.grammar, pygram.python_grammar_no_print_statement)
  36. self.assertIs(rt.driver.grammar,
  37. pygram.python_grammar_no_print_statement)
  38. def test_write_unchanged_files_option(self):
  39. rt = self.rt()
  40. self.assertFalse(rt.write_unchanged_files)
  41. rt = self.rt({"write_unchanged_files" : True})
  42. self.assertTrue(rt.write_unchanged_files)
  43. def test_fixer_loading_helpers(self):
  44. contents = ["explicit", "first", "last", "parrot", "preorder"]
  45. non_prefixed = refactor.get_all_fix_names("myfixes")
  46. prefixed = refactor.get_all_fix_names("myfixes", False)
  47. full_names = refactor.get_fixers_from_package("myfixes")
  48. self.assertEqual(prefixed, ["fix_" + name for name in contents])
  49. self.assertEqual(non_prefixed, contents)
  50. self.assertEqual(full_names,
  51. ["myfixes.fix_" + name for name in contents])
  52. def test_detect_future_features(self):
  53. run = refactor._detect_future_features
  54. fs = frozenset
  55. empty = fs()
  56. self.assertEqual(run(""), empty)
  57. self.assertEqual(run("from __future__ import print_function"),
  58. fs(("print_function",)))
  59. self.assertEqual(run("from __future__ import generators"),
  60. fs(("generators",)))
  61. self.assertEqual(run("from __future__ import generators, feature"),
  62. fs(("generators", "feature")))
  63. inp = "from __future__ import generators, print_function"
  64. self.assertEqual(run(inp), fs(("generators", "print_function")))
  65. inp ="from __future__ import print_function, generators"
  66. self.assertEqual(run(inp), fs(("print_function", "generators")))
  67. inp = "from __future__ import (print_function,)"
  68. self.assertEqual(run(inp), fs(("print_function",)))
  69. inp = "from __future__ import (generators, print_function)"
  70. self.assertEqual(run(inp), fs(("generators", "print_function")))
  71. inp = "from __future__ import (generators, nested_scopes)"
  72. self.assertEqual(run(inp), fs(("generators", "nested_scopes")))
  73. inp = """from __future__ import generators
  74. from __future__ import print_function"""
  75. self.assertEqual(run(inp), fs(("generators", "print_function")))
  76. invalid = ("from",
  77. "from 4",
  78. "from x",
  79. "from x 5",
  80. "from x im",
  81. "from x import",
  82. "from x import 4",
  83. )
  84. for inp in invalid:
  85. self.assertEqual(run(inp), empty)
  86. inp = "'docstring'\nfrom __future__ import print_function"
  87. self.assertEqual(run(inp), fs(("print_function",)))
  88. inp = "'docstring'\n'somng'\nfrom __future__ import print_function"
  89. self.assertEqual(run(inp), empty)
  90. inp = "# comment\nfrom __future__ import print_function"
  91. self.assertEqual(run(inp), fs(("print_function",)))
  92. inp = "# comment\n'doc'\nfrom __future__ import print_function"
  93. self.assertEqual(run(inp), fs(("print_function",)))
  94. inp = "class x: pass\nfrom __future__ import print_function"
  95. self.assertEqual(run(inp), empty)
  96. def test_get_headnode_dict(self):
  97. class NoneFix(fixer_base.BaseFix):
  98. pass
  99. class FileInputFix(fixer_base.BaseFix):
  100. PATTERN = "file_input< any * >"
  101. class SimpleFix(fixer_base.BaseFix):
  102. PATTERN = "'name'"
  103. no_head = NoneFix({}, [])
  104. with_head = FileInputFix({}, [])
  105. simple = SimpleFix({}, [])
  106. d = refactor._get_headnode_dict([no_head, with_head, simple])
  107. top_fixes = d.pop(pygram.python_symbols.file_input)
  108. self.assertEqual(top_fixes, [with_head, no_head])
  109. name_fixes = d.pop(token.NAME)
  110. self.assertEqual(name_fixes, [simple, no_head])
  111. for fixes in d.values():
  112. self.assertEqual(fixes, [no_head])
  113. def test_fixer_loading(self):
  114. from myfixes.fix_first import FixFirst
  115. from myfixes.fix_last import FixLast
  116. from myfixes.fix_parrot import FixParrot
  117. from myfixes.fix_preorder import FixPreorder
  118. rt = self.rt()
  119. pre, post = rt.get_fixers()
  120. self.check_instances(pre, [FixPreorder])
  121. self.check_instances(post, [FixFirst, FixParrot, FixLast])
  122. def test_naughty_fixers(self):
  123. self.assertRaises(ImportError, self.rt, fixers=["not_here"])
  124. self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"])
  125. self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"])
  126. def test_refactor_string(self):
  127. rt = self.rt()
  128. input = "def parrot(): pass\n\n"
  129. tree = rt.refactor_string(input, "<test>")
  130. self.assertNotEqual(str(tree), input)
  131. input = "def f(): pass\n\n"
  132. tree = rt.refactor_string(input, "<test>")
  133. self.assertEqual(str(tree), input)
  134. def test_refactor_stdin(self):
  135. class MyRT(refactor.RefactoringTool):
  136. def print_output(self, old_text, new_text, filename, equal):
  137. results.extend([old_text, new_text, filename, equal])
  138. results = []
  139. rt = MyRT(_DEFAULT_FIXERS)
  140. save = sys.stdin
  141. sys.stdin = io.StringIO("def parrot(): pass\n\n")
  142. try:
  143. rt.refactor_stdin()
  144. finally:
  145. sys.stdin = save
  146. expected = ["def parrot(): pass\n\n",
  147. "def cheese(): pass\n\n",
  148. "<stdin>", False]
  149. self.assertEqual(results, expected)
  150. def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS,
  151. options=None, mock_log_debug=None,
  152. actually_write=True):
  153. test_file = self.init_test_file(test_file)
  154. old_contents = self.read_file(test_file)
  155. rt = self.rt(fixers=fixers, options=options)
  156. if mock_log_debug:
  157. rt.log_debug = mock_log_debug
  158. rt.refactor_file(test_file)
  159. self.assertEqual(old_contents, self.read_file(test_file))
  160. if not actually_write:
  161. return
  162. rt.refactor_file(test_file, True)
  163. new_contents = self.read_file(test_file)
  164. self.assertNotEqual(old_contents, new_contents)
  165. return new_contents
  166. def init_test_file(self, test_file):
  167. tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor")
  168. self.addCleanup(shutil.rmtree, tmpdir)
  169. shutil.copy(test_file, tmpdir)
  170. test_file = os.path.join(tmpdir, os.path.basename(test_file))
  171. os.chmod(test_file, 0o644)
  172. return test_file
  173. def read_file(self, test_file):
  174. with open(test_file, "rb") as fp:
  175. return fp.read()
  176. def refactor_file(self, test_file, fixers=_2TO3_FIXERS):
  177. test_file = self.init_test_file(test_file)
  178. old_contents = self.read_file(test_file)
  179. rt = self.rt(fixers=fixers)
  180. rt.refactor_file(test_file, True)
  181. new_contents = self.read_file(test_file)
  182. return old_contents, new_contents
  183. def test_refactor_file(self):
  184. test_file = os.path.join(FIXER_DIR, "parrot_example.py")
  185. self.check_file_refactoring(test_file, _DEFAULT_FIXERS)
  186. def test_refactor_file_write_unchanged_file(self):
  187. test_file = os.path.join(FIXER_DIR, "parrot_example.py")
  188. debug_messages = []
  189. def recording_log_debug(msg, *args):
  190. debug_messages.append(msg % args)
  191. self.check_file_refactoring(test_file, fixers=(),
  192. options={"write_unchanged_files": True},
  193. mock_log_debug=recording_log_debug,
  194. actually_write=False)
  195. # Testing that it logged this message when write=False was passed is
  196. # sufficient to see that it did not bail early after "No changes".
  197. message_regex = r"Not writing changes to .*%s" % \
  198. re.escape(os.sep + os.path.basename(test_file))
  199. for message in debug_messages:
  200. if "Not writing changes" in message:
  201. self.assertRegex(message, message_regex)
  202. break
  203. else:
  204. self.fail("%r not matched in %r" % (message_regex, debug_messages))
  205. def test_refactor_dir(self):
  206. def check(structure, expected):
  207. def mock_refactor_file(self, f, *args):
  208. got.append(f)
  209. save_func = refactor.RefactoringTool.refactor_file
  210. refactor.RefactoringTool.refactor_file = mock_refactor_file
  211. rt = self.rt()
  212. got = []
  213. dir = tempfile.mkdtemp(prefix="2to3-test_refactor")
  214. try:
  215. os.mkdir(os.path.join(dir, "a_dir"))
  216. for fn in structure:
  217. open(os.path.join(dir, fn), "wb").close()
  218. rt.refactor_dir(dir)
  219. finally:
  220. refactor.RefactoringTool.refactor_file = save_func
  221. shutil.rmtree(dir)
  222. self.assertEqual(got,
  223. [os.path.join(dir, path) for path in expected])
  224. check([], [])
  225. tree = ["nothing",
  226. "hi.py",
  227. ".dumb",
  228. ".after.py",
  229. "notpy.npy",
  230. "sappy"]
  231. expected = ["hi.py"]
  232. check(tree, expected)
  233. tree = ["hi.py",
  234. os.path.join("a_dir", "stuff.py")]
  235. check(tree, tree)
  236. def test_file_encoding(self):
  237. fn = os.path.join(TEST_DATA_DIR, "different_encoding.py")
  238. self.check_file_refactoring(fn)
  239. def test_false_file_encoding(self):
  240. fn = os.path.join(TEST_DATA_DIR, "false_encoding.py")
  241. data = self.check_file_refactoring(fn)
  242. def test_bom(self):
  243. fn = os.path.join(TEST_DATA_DIR, "bom.py")
  244. data = self.check_file_refactoring(fn)
  245. self.assertTrue(data.startswith(codecs.BOM_UTF8))
  246. def test_crlf_newlines(self):
  247. old_sep = os.linesep
  248. os.linesep = "\r\n"
  249. try:
  250. fn = os.path.join(TEST_DATA_DIR, "crlf.py")
  251. fixes = refactor.get_fixers_from_package("lib2to3.fixes")
  252. self.check_file_refactoring(fn, fixes)
  253. finally:
  254. os.linesep = old_sep
  255. def test_crlf_unchanged(self):
  256. fn = os.path.join(TEST_DATA_DIR, "crlf.py")
  257. old, new = self.refactor_file(fn)
  258. self.assertIn(b"\r\n", old)
  259. self.assertIn(b"\r\n", new)
  260. self.assertNotIn(b"\r\r\n", new)
  261. def test_refactor_docstring(self):
  262. rt = self.rt()
  263. doc = """
  264. >>> example()
  265. 42
  266. """
  267. out = rt.refactor_docstring(doc, "<test>")
  268. self.assertEqual(out, doc)
  269. doc = """
  270. >>> def parrot():
  271. ... return 43
  272. """
  273. out = rt.refactor_docstring(doc, "<test>")
  274. self.assertNotEqual(out, doc)
  275. def test_explicit(self):
  276. from myfixes.fix_explicit import FixExplicit
  277. rt = self.rt(fixers=["myfixes.fix_explicit"])
  278. self.assertEqual(len(rt.post_order), 0)
  279. rt = self.rt(explicit=["myfixes.fix_explicit"])
  280. for fix in rt.post_order:
  281. if isinstance(fix, FixExplicit):
  282. break
  283. else:
  284. self.fail("explicit fixer not loaded")