| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- #!/usr/bin/env python3
- #
- # Copyright (c) 2021 Project CHIP Authors
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """Wrapper and utility functions around sqlite3"""
- import sqlite3
- from typing import List, Optional
- import pandas as pd # type: ignore
- from memdf import Config, ConfigDescription
# Option descriptions contributed by this module: a 'database' option group
# holding a single 'database.file' setting (the SQLite file name, which also
# gets the alias '--db' — presumably a command-line flag; see memdf.Config).
CONFIG: ConfigDescription = {
    Config.group_def('database'): dict(title='database options'),
    'database.file': dict(
        help='Sqlite3 file',
        metavar='FILENAME',
        default=None,
        argparse=dict(alias=['--db']),
    ),
}
class Database:
    """Wrapper and utility functions around sqlite3.

    Subclasses may set ``on_open`` (SQL statements run whenever a new
    connection is opened) and ``on_writable`` (statements additionally run
    when the database is opened writable).

    Usable as a context manager: ``with Database(path) as db: ...`` opens the
    connection on entry and closes it on exit.
    """

    # SQL statements executed, in order, each time a new connection is opened.
    on_open: Optional[List[str]] = None
    # SQL statements executed after ``on_open``, only for writable databases.
    on_writable: Optional[List[str]] = None

    def __init__(self, filename: str, writable: bool = True):
        """Record connection settings; the database is not opened here.

        Args:
            filename: Path of the SQLite database file.
            writable: If False, the file is opened read-only (``mode=ro``).
        """
        self.filename = filename
        self.writable = writable
        self.con: Optional[sqlite3.Connection] = None

    def __enter__(self):
        return self.open()

    def __exit__(self, et, ev, traceback):
        self.close()
        return False  # never suppress exceptions from the with-body

    def open(self):
        """Open and initialize the database connection.

        Does nothing if a connection is already open.

        Returns:
            ``self``, to allow chaining.
        """
        if not self.con:
            # In an SQLite URI, '?' starts the query, '#' starts the fragment
            # and '%' introduces an escape. Percent-encode them so filenames
            # containing these characters are opened literally. '%' must be
            # encoded first so the escapes added afterwards are not re-encoded.
            path = (self.filename.replace('%', '%25')
                    .replace('?', '%3F')
                    .replace('#', '%23'))
            db = 'file:' + path
            if not self.writable:
                db += '?mode=ro'
            self.con = sqlite3.connect(db, uri=True)
            if self.on_open:
                for statement in self.on_open:
                    self.con.execute(statement)
            if self.writable and self.on_writable:
                for statement in self.on_writable:
                    self.con.execute(statement)
        return self

    def close(self):
        """Close the connection, if open. Returns ``self``."""
        if self.con:
            self.con.close()
            self.con = None
        return self

    def connection(self) -> sqlite3.Connection:
        """Return the open connection; the database must have been opened."""
        assert self.con
        return self.con

    def execute(self, query, parameters=None):
        """Execute ``query`` on the open connection and return the cursor.

        ``parameters`` are bound to the query's placeholders when given and
        non-empty.
        """
        con = self.connection()
        if parameters:
            return con.execute(query, parameters)
        return con.execute(query)

    def commit(self):
        """Commit the current transaction. Returns ``self``."""
        self.connection().commit()
        return self

    def store(self, table: str, **kwargs):
        """Insert the data if it does not already exist.

        Each keyword is a column name and its value. A row counts as already
        existing when the insert would violate a uniqueness constraint on
        ``table`` (``ON CONFLICT DO NOTHING``, which needs SQLite >= 3.24).
        ``table`` and the column names are interpolated into the SQL text, so
        they must be trusted identifiers. Only the values are bound as
        parameters.
        """
        q = (f"INSERT INTO {table} ({','.join(kwargs.keys())})"
             f" VALUES ({','.join('?' * len(kwargs))})"
             f" ON CONFLICT DO NOTHING")
        v = list(kwargs.values())
        self.connection().execute(q, v)

    def get_matching(self, table: str, columns: List[str], **kwargs):
        """Select ``columns`` from the rows of ``table`` matching ``kwargs``.

        Each keyword is a column name that must equal the given value. A
        ``None`` value is compared with ``IS`` so that it matches SQL NULL,
        because ``col = NULL`` never matches any row.

        Returns:
            The sqlite3 cursor for the query.
        """
        conditions = ' AND '.join(
            f'{column} IS ?' if value is None else f'{column}=?'
            for column, value in kwargs.items())
        q = f"SELECT {','.join(columns)} FROM {table} WHERE {conditions}"
        v = list(kwargs.values())
        return self.connection().execute(q, v)

    def get_matching_id(self, table: str, **kwargs):
        """Return the ``id`` of the first row matching ``kwargs``, or None."""
        cur = self.get_matching(table, ['id'], **kwargs)
        row = cur.fetchone()
        if row:
            return row[0]
        return None

    def store_and_return_id(self, table: str, **kwargs) -> Optional[int]:
        """Insert the row unless it already exists, then return its ``id``."""
        self.store(table, **kwargs)
        return self.get_matching_id(table, **kwargs)

    def data_frame(self, query, parameters=None) -> pd.DataFrame:
        """Return the results of a query as a DataFrame.

        Column names come from the cursor description. The query text is
        stored in ``df.attrs['title']``.
        """
        cur = self.execute(query, parameters)
        columns = [i[0] for i in cur.description]
        df = pd.DataFrame(cur.fetchall(), columns=columns)
        # NOTE(review): commits any pending transaction after reading;
        # presumably so callers need not commit separately. Confirm intent.
        self.commit()
        df.attrs = {'title': query}
        return df
|