testwith.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. import unittest
  2. from warnings import catch_warnings
  3. from unittest.test.testmock.support import is_instance
  4. from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
  5. something = sentinel.Something
  6. something_else = sentinel.SomethingElse
  7. class WithTest(unittest.TestCase):
  8. def test_with_statement(self):
  9. with patch('%s.something' % __name__, sentinel.Something2):
  10. self.assertEqual(something, sentinel.Something2, "unpatched")
  11. self.assertEqual(something, sentinel.Something)
  12. def test_with_statement_exception(self):
  13. try:
  14. with patch('%s.something' % __name__, sentinel.Something2):
  15. self.assertEqual(something, sentinel.Something2, "unpatched")
  16. raise Exception('pow')
  17. except Exception:
  18. pass
  19. else:
  20. self.fail("patch swallowed exception")
  21. self.assertEqual(something, sentinel.Something)
  22. def test_with_statement_as(self):
  23. with patch('%s.something' % __name__) as mock_something:
  24. self.assertEqual(something, mock_something, "unpatched")
  25. self.assertTrue(is_instance(mock_something, MagicMock),
  26. "patching wrong type")
  27. self.assertEqual(something, sentinel.Something)
  28. def test_patch_object_with_statement(self):
  29. class Foo(object):
  30. something = 'foo'
  31. original = Foo.something
  32. with patch.object(Foo, 'something'):
  33. self.assertNotEqual(Foo.something, original, "unpatched")
  34. self.assertEqual(Foo.something, original)
  35. def test_with_statement_nested(self):
  36. with catch_warnings(record=True):
  37. with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else:
  38. self.assertEqual(something, mock_something, "unpatched")
  39. self.assertEqual(something_else, mock_something_else,
  40. "unpatched")
  41. self.assertEqual(something, sentinel.Something)
  42. self.assertEqual(something_else, sentinel.SomethingElse)
  43. def test_with_statement_specified(self):
  44. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  45. self.assertEqual(something, mock_something, "unpatched")
  46. self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
  47. self.assertEqual(something, sentinel.Something)
  48. def testContextManagerMocking(self):
  49. mock = Mock()
  50. mock.__enter__ = Mock()
  51. mock.__exit__ = Mock()
  52. mock.__exit__.return_value = False
  53. with mock as m:
  54. self.assertEqual(m, mock.__enter__.return_value)
  55. mock.__enter__.assert_called_with()
  56. mock.__exit__.assert_called_with(None, None, None)
  57. def test_context_manager_with_magic_mock(self):
  58. mock = MagicMock()
  59. with self.assertRaises(TypeError):
  60. with mock:
  61. 'foo' + 3
  62. mock.__enter__.assert_called_with()
  63. self.assertTrue(mock.__exit__.called)
  64. def test_with_statement_same_attribute(self):
  65. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  66. self.assertEqual(something, mock_something, "unpatched")
  67. with patch('%s.something' % __name__) as mock_again:
  68. self.assertEqual(something, mock_again, "unpatched")
  69. self.assertEqual(something, mock_something,
  70. "restored with wrong instance")
  71. self.assertEqual(something, sentinel.Something, "not restored")
  72. def test_with_statement_imbricated(self):
  73. with patch('%s.something' % __name__) as mock_something:
  74. self.assertEqual(something, mock_something, "unpatched")
  75. with patch('%s.something_else' % __name__) as mock_something_else:
  76. self.assertEqual(something_else, mock_something_else,
  77. "unpatched")
  78. self.assertEqual(something, sentinel.Something)
  79. self.assertEqual(something_else, sentinel.SomethingElse)
  80. def test_dict_context_manager(self):
  81. foo = {}
  82. with patch.dict(foo, {'a': 'b'}):
  83. self.assertEqual(foo, {'a': 'b'})
  84. self.assertEqual(foo, {})
  85. with self.assertRaises(NameError):
  86. with patch.dict(foo, {'a': 'b'}):
  87. self.assertEqual(foo, {'a': 'b'})
  88. raise NameError('Konrad')
  89. self.assertEqual(foo, {})
  90. def test_double_patch_instance_method(self):
  91. class C:
  92. def f(self):
  93. pass
  94. c = C()
  95. with patch.object(c, 'f', autospec=True) as patch1:
  96. with patch.object(c, 'f', autospec=True) as patch2:
  97. c.f()
  98. self.assertEqual(patch2.call_count, 1)
  99. self.assertEqual(patch1.call_count, 0)
  100. c.f()
  101. self.assertEqual(patch1.call_count, 1)
  102. class TestMockOpen(unittest.TestCase):
  103. def test_mock_open(self):
  104. mock = mock_open()
  105. with patch('%s.open' % __name__, mock, create=True) as patched:
  106. self.assertIs(patched, mock)
  107. open('foo')
  108. mock.assert_called_once_with('foo')
  109. def test_mock_open_context_manager(self):
  110. mock = mock_open()
  111. handle = mock.return_value
  112. with patch('%s.open' % __name__, mock, create=True):
  113. with open('foo') as f:
  114. f.read()
  115. expected_calls = [call('foo'), call().__enter__(), call().read(),
  116. call().__exit__(None, None, None)]
  117. self.assertEqual(mock.mock_calls, expected_calls)
  118. self.assertIs(f, handle)
  119. def test_mock_open_context_manager_multiple_times(self):
  120. mock = mock_open()
  121. with patch('%s.open' % __name__, mock, create=True):
  122. with open('foo') as f:
  123. f.read()
  124. with open('bar') as f:
  125. f.read()
  126. expected_calls = [
  127. call('foo'), call().__enter__(), call().read(),
  128. call().__exit__(None, None, None),
  129. call('bar'), call().__enter__(), call().read(),
  130. call().__exit__(None, None, None)]
  131. self.assertEqual(mock.mock_calls, expected_calls)
  132. def test_explicit_mock(self):
  133. mock = MagicMock()
  134. mock_open(mock)
  135. with patch('%s.open' % __name__, mock, create=True) as patched:
  136. self.assertIs(patched, mock)
  137. open('foo')
  138. mock.assert_called_once_with('foo')
  139. def test_read_data(self):
  140. mock = mock_open(read_data='foo')
  141. with patch('%s.open' % __name__, mock, create=True):
  142. h = open('bar')
  143. result = h.read()
  144. self.assertEqual(result, 'foo')
  145. def test_readline_data(self):
  146. # Check that readline will return all the lines from the fake file
  147. # And that once fully consumed, readline will return an empty string.
  148. mock = mock_open(read_data='foo\nbar\nbaz\n')
  149. with patch('%s.open' % __name__, mock, create=True):
  150. h = open('bar')
  151. line1 = h.readline()
  152. line2 = h.readline()
  153. line3 = h.readline()
  154. self.assertEqual(line1, 'foo\n')
  155. self.assertEqual(line2, 'bar\n')
  156. self.assertEqual(line3, 'baz\n')
  157. self.assertEqual(h.readline(), '')
  158. # Check that we properly emulate a file that doesn't end in a newline
  159. mock = mock_open(read_data='foo')
  160. with patch('%s.open' % __name__, mock, create=True):
  161. h = open('bar')
  162. result = h.readline()
  163. self.assertEqual(result, 'foo')
  164. self.assertEqual(h.readline(), '')
  165. def test_dunder_iter_data(self):
  166. # Check that dunder_iter will return all the lines from the fake file.
  167. mock = mock_open(read_data='foo\nbar\nbaz\n')
  168. with patch('%s.open' % __name__, mock, create=True):
  169. h = open('bar')
  170. lines = [l for l in h]
  171. self.assertEqual(lines[0], 'foo\n')
  172. self.assertEqual(lines[1], 'bar\n')
  173. self.assertEqual(lines[2], 'baz\n')
  174. self.assertEqual(h.readline(), '')
  175. with self.assertRaises(StopIteration):
  176. next(h)
  177. def test_next_data(self):
  178. # Check that next will correctly return the next available
  179. # line and plays well with the dunder_iter part.
  180. mock = mock_open(read_data='foo\nbar\nbaz\n')
  181. with patch('%s.open' % __name__, mock, create=True):
  182. h = open('bar')
  183. line1 = next(h)
  184. line2 = next(h)
  185. lines = [l for l in h]
  186. self.assertEqual(line1, 'foo\n')
  187. self.assertEqual(line2, 'bar\n')
  188. self.assertEqual(lines[0], 'baz\n')
  189. self.assertEqual(h.readline(), '')
  190. def test_readlines_data(self):
  191. # Test that emulating a file that ends in a newline character works
  192. mock = mock_open(read_data='foo\nbar\nbaz\n')
  193. with patch('%s.open' % __name__, mock, create=True):
  194. h = open('bar')
  195. result = h.readlines()
  196. self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
  197. # Test that files without a final newline will also be correctly
  198. # emulated
  199. mock = mock_open(read_data='foo\nbar\nbaz')
  200. with patch('%s.open' % __name__, mock, create=True):
  201. h = open('bar')
  202. result = h.readlines()
  203. self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
  204. def test_read_bytes(self):
  205. mock = mock_open(read_data=b'\xc6')
  206. with patch('%s.open' % __name__, mock, create=True):
  207. with open('abc', 'rb') as f:
  208. result = f.read()
  209. self.assertEqual(result, b'\xc6')
  210. def test_readline_bytes(self):
  211. m = mock_open(read_data=b'abc\ndef\nghi\n')
  212. with patch('%s.open' % __name__, m, create=True):
  213. with open('abc', 'rb') as f:
  214. line1 = f.readline()
  215. line2 = f.readline()
  216. line3 = f.readline()
  217. self.assertEqual(line1, b'abc\n')
  218. self.assertEqual(line2, b'def\n')
  219. self.assertEqual(line3, b'ghi\n')
  220. def test_readlines_bytes(self):
  221. m = mock_open(read_data=b'abc\ndef\nghi\n')
  222. with patch('%s.open' % __name__, m, create=True):
  223. with open('abc', 'rb') as f:
  224. result = f.readlines()
  225. self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])
  226. def test_mock_open_read_with_argument(self):
  227. # At one point calling read with an argument was broken
  228. # for mocks returned by mock_open
  229. some_data = 'foo\nbar\nbaz'
  230. mock = mock_open(read_data=some_data)
  231. self.assertEqual(mock().read(10), some_data[:10])
  232. self.assertEqual(mock().read(10), some_data[:10])
  233. f = mock()
  234. self.assertEqual(f.read(10), some_data[:10])
  235. self.assertEqual(f.read(10), some_data[10:])
  236. def test_interleaved_reads(self):
  237. # Test that calling read, readline, and readlines pulls data
  238. # sequentially from the data we preload with
  239. mock = mock_open(read_data='foo\nbar\nbaz\n')
  240. with patch('%s.open' % __name__, mock, create=True):
  241. h = open('bar')
  242. line1 = h.readline()
  243. rest = h.readlines()
  244. self.assertEqual(line1, 'foo\n')
  245. self.assertEqual(rest, ['bar\n', 'baz\n'])
  246. mock = mock_open(read_data='foo\nbar\nbaz\n')
  247. with patch('%s.open' % __name__, mock, create=True):
  248. h = open('bar')
  249. line1 = h.readline()
  250. rest = h.read()
  251. self.assertEqual(line1, 'foo\n')
  252. self.assertEqual(rest, 'bar\nbaz\n')
  253. def test_overriding_return_values(self):
  254. mock = mock_open(read_data='foo')
  255. handle = mock()
  256. handle.read.return_value = 'bar'
  257. handle.readline.return_value = 'bar'
  258. handle.readlines.return_value = ['bar']
  259. self.assertEqual(handle.read(), 'bar')
  260. self.assertEqual(handle.readline(), 'bar')
  261. self.assertEqual(handle.readlines(), ['bar'])
  262. # call repeatedly to check that a StopIteration is not propagated
  263. self.assertEqual(handle.readline(), 'bar')
  264. self.assertEqual(handle.readline(), 'bar')
  265. if __name__ == '__main__':
  266. unittest.main()