userfunctions.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. #-*- coding: iso-8859-1 -*-
  2. # pysqlite2/test/userfunctions.py: tests for user-defined functions and
  3. # aggregates.
  4. #
  5. # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
  6. #
  7. # This file is part of pysqlite.
  8. #
  9. # This software is provided 'as-is', without any express or implied
  10. # warranty. In no event will the authors be held liable for any damages
  11. # arising from the use of this software.
  12. #
  13. # Permission is granted to anyone to use this software for any purpose,
  14. # including commercial applications, and to alter it and redistribute it
  15. # freely, subject to the following restrictions:
  16. #
  17. # 1. The origin of this software must not be misrepresented; you must not
  18. # claim that you wrote the original software. If you use this software
  19. # in a product, an acknowledgment in the product documentation would be
  20. # appreciated but is not required.
  21. # 2. Altered source versions must be plainly marked as such, and must not be
  22. # misrepresented as being the original software.
  23. # 3. This notice may not be removed or altered from any source distribution.
  24. import unittest
  25. import sqlite3 as sqlite
  26. def func_returntext():
  27. return "foo"
  28. def func_returnunicode():
  29. return "bar"
  30. def func_returnint():
  31. return 42
  32. def func_returnfloat():
  33. return 3.14
  34. def func_returnnull():
  35. return None
  36. def func_returnblob():
  37. return b"blob"
  38. def func_returnlonglong():
  39. return 1<<31
  40. def func_raiseexception():
  41. 5/0
  42. def func_isstring(v):
  43. return type(v) is str
  44. def func_isint(v):
  45. return type(v) is int
  46. def func_isfloat(v):
  47. return type(v) is float
  48. def func_isnone(v):
  49. return type(v) is type(None)
  50. def func_isblob(v):
  51. return isinstance(v, (bytes, memoryview))
  52. def func_islonglong(v):
  53. return isinstance(v, int) and v >= 1<<31
  54. def func(*args):
  55. return len(args)
  56. class AggrNoStep:
  57. def __init__(self):
  58. pass
  59. def finalize(self):
  60. return 1
  61. class AggrNoFinalize:
  62. def __init__(self):
  63. pass
  64. def step(self, x):
  65. pass
  66. class AggrExceptionInInit:
  67. def __init__(self):
  68. 5/0
  69. def step(self, x):
  70. pass
  71. def finalize(self):
  72. pass
  73. class AggrExceptionInStep:
  74. def __init__(self):
  75. pass
  76. def step(self, x):
  77. 5/0
  78. def finalize(self):
  79. return 42
  80. class AggrExceptionInFinalize:
  81. def __init__(self):
  82. pass
  83. def step(self, x):
  84. pass
  85. def finalize(self):
  86. 5/0
  87. class AggrCheckType:
  88. def __init__(self):
  89. self.val = None
  90. def step(self, whichType, val):
  91. theType = {"str": str, "int": int, "float": float, "None": type(None),
  92. "blob": bytes}
  93. self.val = int(theType[whichType] is type(val))
  94. def finalize(self):
  95. return self.val
  96. class AggrCheckTypes:
  97. def __init__(self):
  98. self.val = 0
  99. def step(self, whichType, *vals):
  100. theType = {"str": str, "int": int, "float": float, "None": type(None),
  101. "blob": bytes}
  102. for val in vals:
  103. self.val += int(theType[whichType] is type(val))
  104. def finalize(self):
  105. return self.val
  106. class AggrSum:
  107. def __init__(self):
  108. self.val = 0.0
  109. def step(self, val):
  110. self.val += val
  111. def finalize(self):
  112. return self.val
  113. class FunctionTests(unittest.TestCase):
  114. def setUp(self):
  115. self.con = sqlite.connect(":memory:")
  116. self.con.create_function("returntext", 0, func_returntext)
  117. self.con.create_function("returnunicode", 0, func_returnunicode)
  118. self.con.create_function("returnint", 0, func_returnint)
  119. self.con.create_function("returnfloat", 0, func_returnfloat)
  120. self.con.create_function("returnnull", 0, func_returnnull)
  121. self.con.create_function("returnblob", 0, func_returnblob)
  122. self.con.create_function("returnlonglong", 0, func_returnlonglong)
  123. self.con.create_function("raiseexception", 0, func_raiseexception)
  124. self.con.create_function("isstring", 1, func_isstring)
  125. self.con.create_function("isint", 1, func_isint)
  126. self.con.create_function("isfloat", 1, func_isfloat)
  127. self.con.create_function("isnone", 1, func_isnone)
  128. self.con.create_function("isblob", 1, func_isblob)
  129. self.con.create_function("islonglong", 1, func_islonglong)
  130. self.con.create_function("spam", -1, func)
  131. def tearDown(self):
  132. self.con.close()
  133. def CheckFuncErrorOnCreate(self):
  134. with self.assertRaises(sqlite.OperationalError):
  135. self.con.create_function("bla", -100, lambda x: 2*x)
  136. def CheckFuncRefCount(self):
  137. def getfunc():
  138. def f():
  139. return 1
  140. return f
  141. f = getfunc()
  142. globals()["foo"] = f
  143. # self.con.create_function("reftest", 0, getfunc())
  144. self.con.create_function("reftest", 0, f)
  145. cur = self.con.cursor()
  146. cur.execute("select reftest()")
  147. def CheckFuncReturnText(self):
  148. cur = self.con.cursor()
  149. cur.execute("select returntext()")
  150. val = cur.fetchone()[0]
  151. self.assertEqual(type(val), str)
  152. self.assertEqual(val, "foo")
  153. def CheckFuncReturnUnicode(self):
  154. cur = self.con.cursor()
  155. cur.execute("select returnunicode()")
  156. val = cur.fetchone()[0]
  157. self.assertEqual(type(val), str)
  158. self.assertEqual(val, "bar")
  159. def CheckFuncReturnInt(self):
  160. cur = self.con.cursor()
  161. cur.execute("select returnint()")
  162. val = cur.fetchone()[0]
  163. self.assertEqual(type(val), int)
  164. self.assertEqual(val, 42)
  165. def CheckFuncReturnFloat(self):
  166. cur = self.con.cursor()
  167. cur.execute("select returnfloat()")
  168. val = cur.fetchone()[0]
  169. self.assertEqual(type(val), float)
  170. if val < 3.139 or val > 3.141:
  171. self.fail("wrong value")
  172. def CheckFuncReturnNull(self):
  173. cur = self.con.cursor()
  174. cur.execute("select returnnull()")
  175. val = cur.fetchone()[0]
  176. self.assertEqual(type(val), type(None))
  177. self.assertEqual(val, None)
  178. def CheckFuncReturnBlob(self):
  179. cur = self.con.cursor()
  180. cur.execute("select returnblob()")
  181. val = cur.fetchone()[0]
  182. self.assertEqual(type(val), bytes)
  183. self.assertEqual(val, b"blob")
  184. def CheckFuncReturnLongLong(self):
  185. cur = self.con.cursor()
  186. cur.execute("select returnlonglong()")
  187. val = cur.fetchone()[0]
  188. self.assertEqual(val, 1<<31)
  189. def CheckFuncException(self):
  190. cur = self.con.cursor()
  191. with self.assertRaises(sqlite.OperationalError) as cm:
  192. cur.execute("select raiseexception()")
  193. cur.fetchone()
  194. self.assertEqual(str(cm.exception), 'user-defined function raised exception')
  195. def CheckParamString(self):
  196. cur = self.con.cursor()
  197. cur.execute("select isstring(?)", ("foo",))
  198. val = cur.fetchone()[0]
  199. self.assertEqual(val, 1)
  200. def CheckParamInt(self):
  201. cur = self.con.cursor()
  202. cur.execute("select isint(?)", (42,))
  203. val = cur.fetchone()[0]
  204. self.assertEqual(val, 1)
  205. def CheckParamFloat(self):
  206. cur = self.con.cursor()
  207. cur.execute("select isfloat(?)", (3.14,))
  208. val = cur.fetchone()[0]
  209. self.assertEqual(val, 1)
  210. def CheckParamNone(self):
  211. cur = self.con.cursor()
  212. cur.execute("select isnone(?)", (None,))
  213. val = cur.fetchone()[0]
  214. self.assertEqual(val, 1)
  215. def CheckParamBlob(self):
  216. cur = self.con.cursor()
  217. cur.execute("select isblob(?)", (memoryview(b"blob"),))
  218. val = cur.fetchone()[0]
  219. self.assertEqual(val, 1)
  220. def CheckParamLongLong(self):
  221. cur = self.con.cursor()
  222. cur.execute("select islonglong(?)", (1<<42,))
  223. val = cur.fetchone()[0]
  224. self.assertEqual(val, 1)
  225. def CheckAnyArguments(self):
  226. cur = self.con.cursor()
  227. cur.execute("select spam(?, ?)", (1, 2))
  228. val = cur.fetchone()[0]
  229. self.assertEqual(val, 2)
  230. class AggregateTests(unittest.TestCase):
  231. def setUp(self):
  232. self.con = sqlite.connect(":memory:")
  233. cur = self.con.cursor()
  234. cur.execute("""
  235. create table test(
  236. t text,
  237. i integer,
  238. f float,
  239. n,
  240. b blob
  241. )
  242. """)
  243. cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
  244. ("foo", 5, 3.14, None, memoryview(b"blob"),))
  245. self.con.create_aggregate("nostep", 1, AggrNoStep)
  246. self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
  247. self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
  248. self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
  249. self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
  250. self.con.create_aggregate("checkType", 2, AggrCheckType)
  251. self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
  252. self.con.create_aggregate("mysum", 1, AggrSum)
  253. def tearDown(self):
  254. #self.cur.close()
  255. #self.con.close()
  256. pass
  257. def CheckAggrErrorOnCreate(self):
  258. with self.assertRaises(sqlite.OperationalError):
  259. self.con.create_function("bla", -100, AggrSum)
  260. def CheckAggrNoStep(self):
  261. cur = self.con.cursor()
  262. with self.assertRaises(AttributeError) as cm:
  263. cur.execute("select nostep(t) from test")
  264. self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
  265. def CheckAggrNoFinalize(self):
  266. cur = self.con.cursor()
  267. with self.assertRaises(sqlite.OperationalError) as cm:
  268. cur.execute("select nofinalize(t) from test")
  269. val = cur.fetchone()[0]
  270. self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
  271. def CheckAggrExceptionInInit(self):
  272. cur = self.con.cursor()
  273. with self.assertRaises(sqlite.OperationalError) as cm:
  274. cur.execute("select excInit(t) from test")
  275. val = cur.fetchone()[0]
  276. self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
  277. def CheckAggrExceptionInStep(self):
  278. cur = self.con.cursor()
  279. with self.assertRaises(sqlite.OperationalError) as cm:
  280. cur.execute("select excStep(t) from test")
  281. val = cur.fetchone()[0]
  282. self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
  283. def CheckAggrExceptionInFinalize(self):
  284. cur = self.con.cursor()
  285. with self.assertRaises(sqlite.OperationalError) as cm:
  286. cur.execute("select excFinalize(t) from test")
  287. val = cur.fetchone()[0]
  288. self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
  289. def CheckAggrCheckParamStr(self):
  290. cur = self.con.cursor()
  291. cur.execute("select checkType('str', ?)", ("foo",))
  292. val = cur.fetchone()[0]
  293. self.assertEqual(val, 1)
  294. def CheckAggrCheckParamInt(self):
  295. cur = self.con.cursor()
  296. cur.execute("select checkType('int', ?)", (42,))
  297. val = cur.fetchone()[0]
  298. self.assertEqual(val, 1)
  299. def CheckAggrCheckParamsInt(self):
  300. cur = self.con.cursor()
  301. cur.execute("select checkTypes('int', ?, ?)", (42, 24))
  302. val = cur.fetchone()[0]
  303. self.assertEqual(val, 2)
  304. def CheckAggrCheckParamFloat(self):
  305. cur = self.con.cursor()
  306. cur.execute("select checkType('float', ?)", (3.14,))
  307. val = cur.fetchone()[0]
  308. self.assertEqual(val, 1)
  309. def CheckAggrCheckParamNone(self):
  310. cur = self.con.cursor()
  311. cur.execute("select checkType('None', ?)", (None,))
  312. val = cur.fetchone()[0]
  313. self.assertEqual(val, 1)
  314. def CheckAggrCheckParamBlob(self):
  315. cur = self.con.cursor()
  316. cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
  317. val = cur.fetchone()[0]
  318. self.assertEqual(val, 1)
  319. def CheckAggrCheckAggrSum(self):
  320. cur = self.con.cursor()
  321. cur.execute("delete from test")
  322. cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
  323. cur.execute("select mysum(i) from test")
  324. val = cur.fetchone()[0]
  325. self.assertEqual(val, 60)
  326. class AuthorizerTests(unittest.TestCase):
  327. @staticmethod
  328. def authorizer_cb(action, arg1, arg2, dbname, source):
  329. if action != sqlite.SQLITE_SELECT:
  330. return sqlite.SQLITE_DENY
  331. if arg2 == 'c2' or arg1 == 't2':
  332. return sqlite.SQLITE_DENY
  333. return sqlite.SQLITE_OK
  334. def setUp(self):
  335. self.con = sqlite.connect(":memory:")
  336. self.con.executescript("""
  337. create table t1 (c1, c2);
  338. create table t2 (c1, c2);
  339. insert into t1 (c1, c2) values (1, 2);
  340. insert into t2 (c1, c2) values (4, 5);
  341. """)
  342. # For our security test:
  343. self.con.execute("select c2 from t2")
  344. self.con.set_authorizer(self.authorizer_cb)
  345. def tearDown(self):
  346. pass
  347. def test_table_access(self):
  348. with self.assertRaises(sqlite.DatabaseError) as cm:
  349. self.con.execute("select * from t2")
  350. self.assertIn('prohibited', str(cm.exception))
  351. def test_column_access(self):
  352. with self.assertRaises(sqlite.DatabaseError) as cm:
  353. self.con.execute("select c2 from t1")
  354. self.assertIn('prohibited', str(cm.exception))
  355. class AuthorizerRaiseExceptionTests(AuthorizerTests):
  356. @staticmethod
  357. def authorizer_cb(action, arg1, arg2, dbname, source):
  358. if action != sqlite.SQLITE_SELECT:
  359. raise ValueError
  360. if arg2 == 'c2' or arg1 == 't2':
  361. raise ValueError
  362. return sqlite.SQLITE_OK
  363. class AuthorizerIllegalTypeTests(AuthorizerTests):
  364. @staticmethod
  365. def authorizer_cb(action, arg1, arg2, dbname, source):
  366. if action != sqlite.SQLITE_SELECT:
  367. return 0.0
  368. if arg2 == 'c2' or arg1 == 't2':
  369. return 0.0
  370. return sqlite.SQLITE_OK
  371. class AuthorizerLargeIntegerTests(AuthorizerTests):
  372. @staticmethod
  373. def authorizer_cb(action, arg1, arg2, dbname, source):
  374. if action != sqlite.SQLITE_SELECT:
  375. return 2**32
  376. if arg2 == 'c2' or arg1 == 't2':
  377. return 2**32
  378. return sqlite.SQLITE_OK
  379. def suite():
  380. function_suite = unittest.makeSuite(FunctionTests, "Check")
  381. aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
  382. authorizer_suite = unittest.makeSuite(AuthorizerTests)
  383. return unittest.TestSuite((
  384. function_suite,
  385. aggregate_suite,
  386. authorizer_suite,
  387. unittest.makeSuite(AuthorizerRaiseExceptionTests),
  388. unittest.makeSuite(AuthorizerIllegalTypeTests),
  389. unittest.makeSuite(AuthorizerLargeIntegerTests),
  390. ))
  391. def test():
  392. runner = unittest.TextTestRunner()
  393. runner.run(suite())
  394. if __name__ == "__main__":
  395. test()