# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"


# Python imports
import io
import os
import pkgutil
import sys
import logging
import operator
import collections
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_matcher as bm


def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fix_names = []
    for finder, name, ispkg in pkgutil.iter_modules(pkg.__path__):
        if name.startswith("fix_"):
            if remove_prefix:
                name = name[4:]
            fix_names.append(name)
    return fix_names
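
# Illustrative usage (not part of the original module); with the standard
# fixer package, the "fix_" module prefix is stripped by default:
#
#     get_all_fix_names("lib2to3.fixes")        # e.g. ['apply', 'asserts', ...]
#     get_all_fix_names("lib2to3.fixes", False) # e.g. ['fix_apply', ...]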

class _EveryNode(Exception):
    pass


def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
        of the pattern types which will match first. """

    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        # Always return leaves.
        if pat.type is None:
            raise _EveryNode
        return {pat.type}

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode  # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % (pat,))
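
# Illustrative (not part of the original module): patterns compiled by
# lib2to3.patcomp usually have a concrete head type, so a fixer only needs
# to be consulted at nodes of that type:
#
#     from lib2to3.patcomp import PatternCompiler
#     pat = PatternCompiler().compile_pattern("'print'")
#     _get_head_types(pat)    # {token.NAME} -- a literal name matches a NAME leaf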

def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.  """
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)
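
# Illustrative (not part of the original module): the returned dict maps
# every grammar symbol and token type to the fixers worth trying there, so
# a traversal can do a constant-time lookup per node; fixers that could
# match anywhere appear in every bucket.
#
#     heads = _get_headnode_dict(fixer_instances)  # fixer_instances: hypothetical
#     for fixer in heads[node.type]:               # node: hypothetical pytree node
#         ...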

def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
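
# Illustrative (not part of the original module):
#
#     get_fixers_from_package("lib2to3.fixes")
#     # e.g. ['lib2to3.fixes.fix_apply', 'lib2to3.fixes.fix_asserts', ...]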

def _identity(obj):
    return obj

def _detect_future_features(source):
    have_docstring = False
    gen = tokenize.generate_tokens(io.StringIO(source).readline)
    def advance():
        tok = next(gen)
        return tok[0], tok[1]
    ignore = frozenset({token.NEWLINE, tokenize.NL, token.COMMENT})
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == "from":
                tp, value = advance()
                if tp != token.NAME or value != "__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != "import":
                    break
                tp, value = advance()
                if tp == token.OP and value == "(":
                    tp, value = advance()
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != ",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return frozenset(features)
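
# Illustrative examples (not part of the original module):
#
#     _detect_future_features("from __future__ import print_function\n")
#     # frozenset({'print_function'})
#     _detect_future_features('"""doc"""\nfrom __future__ import (division, with_statement)\n')
#     # frozenset({'division', 'with_statement'})
#     _detect_future_features("import os\nfrom __future__ import division\n")
#     # frozenset() -- scanning stops at the first non-__future__ statement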

class FixerError(Exception):
    """A fixer could not be loaded."""


class RefactoringTool(object):

    _default_options = {"print_function" : False,
                        "write_unchanged_files" : False}

    CLASS_PREFIX = "Fix" # The prefix for fixer classes
    FILE_PREFIX = "fix_" # The prefix for modules with a fixer within

    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are marked
                explicit (i.e. optional).
        """
        self.fixers = fixer_names
        self.explicit = explicit or []
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)
        if self.options["print_function"]:
            self.grammar = pygram.python_grammar_no_print_statement
        else:
            self.grammar = pygram.python_grammar
        # When this is True, the refactor*() methods will call write_file()
        # even for files that were not changed during refactoring -- but only
        # if the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.files = []  # List of files that were or should be modified

        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = []  # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
                # remove fixers that will be handled by the bottom-up
                # matcher
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)
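
    # Illustrative construction (not part of the original module):
    #
    #     rt = RefactoringTool(get_fixers_from_package("lib2to3.fixes"))
    #
    # After __init__, pre_order/post_order hold the instantiated fixers and
    # the BM-compatible ones are registered with the bottom matcher.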

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            parts = fix_name.split("_")
            class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" % (fix_name, class_name)) from None
            fixer = fix_class(self.options, self.fixer_log)
            if fixer.explicit and self.explicit is not True and \
                    fix_mod_path not in self.explicit:
                self.log_message("Skipping optional fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        key_func = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=key_func)
        post_order_fixers.sort(key=key_func)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Called when an error occurs."""
        # The base implementation simply re-raises the active exception;
        # subclasses may override this to record errors instead.
        raise

    def log_message(self, msg, *args):
        """Hook to log a message."""
        if args:
            msg = msg % args
        self.logger.info(msg)

    def log_debug(self, msg, *args):
        """Hook to log a debug message."""
        if args:
            msg = msg % args
        self.logger.debug(msg)

    def print_output(self, old_text, new_text, filename, equal):
        """Called with the old version, new version, and filename of a
        refactored file."""
        pass

    def refactor(self, items, write=False, doctests_only=False):
        """Refactor a list of files and directories."""
        for dir_or_file in items:
            if os.path.isdir(dir_or_file):
                self.refactor_dir(dir_or_file, write, doctests_only)
            else:
                self.refactor_file(dir_or_file, write, doctests_only)

    def refactor_dir(self, dir_name, write=False, doctests_only=False):
        """Descends down a directory and refactors every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        py_ext = os.extsep + "py"
        for dirpath, dirnames, filenames in os.walk(dir_name):
            self.log_debug("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                if (not name.startswith(".") and
                        os.path.splitext(name)[1] == py_ext):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname, write, doctests_only)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]

    def _read_python_source(self, filename):
        """
        Do our best to decode a Python source file correctly.
        """
        try:
            f = open(filename, "rb")
        except OSError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return None, None
        try:
            encoding = tokenize.detect_encoding(f.readline)[0]
        finally:
            f.close()
        with io.open(filename, "r", encoding=encoding, newline='') as f:
            return f.read(), encoding

    def refactor_file(self, filename, write=False, doctests_only=False):
        """Refactors a file."""
        input, encoding = self._read_python_source(filename)
        if input is None:
            # Reading the file failed.
            return
        input += "\n"  # Silence certain parse errors
        if doctests_only:
            self.log_debug("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(input, filename)
            if self.write_unchanged_files or output != input:
                self.processed_file(output, filename, input, write, encoding)
            else:
                self.log_debug("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(input, filename)
            if self.write_unchanged_files or (tree and tree.was_changed):
                # The [:-1] is to take off the \n we added earlier
                self.processed_file(str(tree)[:-1], filename,
                                    write=write, encoding=encoding)
            else:
                self.log_debug("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        features = _detect_future_features(data)
        if "print_function" in features:
            self.driver.grammar = pygram.python_grammar_no_print_statement
        try:
            tree = self.driver.parse_string(data)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        finally:
            self.driver.grammar = self.grammar
        tree.future_features = features
        self.log_debug("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree
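
    # Illustrative use (not part of the original module); with the standard
    # fixers loaded, an old-style print statement is rewritten:
    #
    #     rt = RefactoringTool(get_fixers_from_package("lib2to3.fixes"))
    #     tree = rt.refactor_string("print 'hello'\n", "<example>")
    #     str(tree)    # "print('hello')\n"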

    def refactor_stdin(self, doctests_only=False):
        input = sys.stdin.read()
        if doctests_only:
            self.log_debug("Refactoring doctests in stdin")
            output = self.refactor_docstring(input, "<stdin>")
            if self.write_unchanged_files or output != input:
                self.processed_file(output, "<stdin>", input)
            else:
                self.log_debug("No doctest changes in stdin")
        else:
            tree = self.refactor_string(input, "<stdin>")
            if self.write_unchanged_files or (tree and tree.was_changed):
                self.processed_file(str(tree), "<stdin>", input)
            else:
                self.log_debug("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        For compatible patterns the bottom matcher module is
        used. Otherwise the tree is traversed node-to-node for
        matches.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """

        for fixer in chain(self.pre_order, self.post_order):
            fixer.start_tree(tree, name)

        # Use traditional matching for the incompatible fixers.
        self.traverse_by(self.bmi_pre_order_heads, tree.pre_order())
        self.traverse_by(self.bmi_post_order_heads, tree.post_order())

        # Obtain a set of candidate nodes.
        match_set = self.BM.run(tree.leaves())

        while any(match_set.values()):
            for fixer in self.BM.fixers:
                if fixer in match_set and match_set[fixer]:
                    # Sort by depth; apply fixers from bottom (of the AST) to top.
                    match_set[fixer].sort(key=pytree.Base.depth, reverse=True)

                    if fixer.keep_line_order:
                        # Some fixers (e.g. fix_imports) must be applied
                        # with the original file's line order.
                        match_set[fixer].sort(key=pytree.Base.get_lineno)

                    for node in list(match_set[fixer]):
                        if node in match_set[fixer]:
                            match_set[fixer].remove(node)

                        try:
                            find_root(node)
                        except ValueError:
                            # This node has been cut off from a
                            # previous transformation; skip.
                            continue

                        if node.fixers_applied and fixer in node.fixers_applied:
                            # Do not apply the same fixer again.
                            continue

                        results = fixer.match(node)

                        if results:
                            new = fixer.transform(node, results)
                            if new is not None:
                                node.replace(new)
                                # new.fixers_applied.append(fixer)
                                for node in new.post_order():
                                    # Do not apply the fixer again to
                                    # this or any subnode.
                                    if not node.fixers_applied:
                                        node.fixers_applied = []
                                    node.fixers_applied.append(fixer)

                                # Update the original match set for
                                # the added code.
                                new_matches = self.BM.run(new.leaves())
                                for fxr in new_matches:
                                    if fxr not in match_set:
                                        match_set[fxr] = []
                                    match_set[fxr].extend(new_matches[fxr])

        for fixer in chain(self.pre_order, self.post_order):
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a mapping of node type to a list of fixer instances.
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None:
                        node.replace(new)
                        node = new

    def processed_file(self, new_text, filename, old_text=None, write=False,
                       encoding=None):
        """
        Called when a file has been refactored and there may be changes.
        """
        self.files.append(filename)
        if old_text is None:
            old_text = self._read_python_source(filename)[0]
            if old_text is None:
                return
        equal = old_text == new_text
        self.print_output(old_text, new_text, filename, equal)
        if equal:
            self.log_debug("No changes to %s", filename)
            if not self.write_unchanged_files:
                return
        if write:
            self.write_file(new_text, filename, old_text, encoding)
        else:
            self.log_debug("Not writing changes to %s", filename)

    def write_file(self, new_text, filename, old_text, encoding=None):
        """Writes a string to a file.

        It first shows a unified diff between the old text and the new text, and
        then rewrites the file; the latter is only done if the write option is
        set.
        """
        try:
            fp = io.open(filename, "w", encoding=encoding, newline='')
        except OSError as err:
            self.log_error("Can't create %s: %s", filename, err)
            return

        with fp:
            try:
                fp.write(new_text)
            except OSError as err:
                self.log_error("Can't write %s: %s", filename, err)
        self.log_debug("Wrote changes to %s", filename)
        self.wrote = True

    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None
        block_lineno = None
        indent = None
        lineno = 0
        for line in input.splitlines(keepends=True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                block.append(line)
            else:
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)
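
    # Illustrative (not part of the original module): in a docstring such as
    #
    #     >>> print 'x'
    #     x
    #
    # the prompt block is collected, refactored in isolation via
    # refactor_doctest(), and spliced back; expected-output lines pass
    # through unchanged.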

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).
        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(keepends=True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        if self.wrote:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                self.log_message(msg, *args, **kwds)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        tree = self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
        tree.future_features = frozenset()
        return tree

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).__next__)
        for type, value, (line0, col0), (line1, col1), line_text in tokens:
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield type, value, (line0, col0), (line1, col1), line_text

    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            prefix = prefix2
        while True:
            yield ""
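
# A minimal end-to-end sketch (illustrative, not part of the original
# module); "example.py" is a hypothetical file, and write=True is required
# to actually rewrite it on disk:
#
#     from lib2to3.refactor import RefactoringTool, get_fixers_from_package
#     rt = RefactoringTool(get_fixers_from_package("lib2to3.fixes"))
#     rt.refactor(["example.py"], write=True)
#     rt.summarize()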

class MultiprocessingUnsupported(Exception):
    pass


class MultiprocessRefactoringTool(RefactoringTool):

    def __init__(self, *args, **kwargs):
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
        self.queue = None
        self.output_lock = None

    def refactor(self, items, write=False, doctests_only=False,
                 num_processes=1):
        if num_processes == 1:
            return super(MultiprocessRefactoringTool, self).refactor(
                items, write, doctests_only)
        try:
            import multiprocessing
        except ImportError:
            raise MultiprocessingUnsupported
        if self.queue is not None:
            raise RuntimeError("already doing multiple processes")
        self.queue = multiprocessing.JoinableQueue()
        self.output_lock = multiprocessing.Lock()
        processes = [multiprocessing.Process(target=self._child)
                     for i in range(num_processes)]
        try:
            for p in processes:
                p.start()
            super(MultiprocessRefactoringTool, self).refactor(items, write,
                                                              doctests_only)
        finally:
            self.queue.join()
            for i in range(num_processes):
                self.queue.put(None)
            for p in processes:
                if p.is_alive():
                    p.join()
            self.queue = None

    def _child(self):
        task = self.queue.get()
        while task is not None:
            args, kwargs = task
            try:
                super(MultiprocessRefactoringTool, self).refactor_file(
                    *args, **kwargs)
            finally:
                self.queue.task_done()
            task = self.queue.get()

    def refactor_file(self, *args, **kwargs):
        if self.queue is not None:
            self.queue.put((args, kwargs))
        else:
            return super(MultiprocessRefactoringTool, self).refactor_file(
                *args, **kwargs)
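
# Illustrative (not part of the original module): same interface as
# RefactoringTool.refactor() plus num_processes; with more than one process,
# refactor_file() calls are queued and consumed by worker processes.
# "pkg/" is a hypothetical directory:
#
#     mprt = MultiprocessRefactoringTool(
#         get_fixers_from_package("lib2to3.fixes"))
#     mprt.refactor(["pkg/"], write=True, num_processes=4)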