diff options
| author | 2021-04-27 00:04:24 +0200 | |
|---|---|---|
| committer | 2021-04-27 00:04:24 +0200 | |
| commit | 4dc4a848a239eae379f32991d443f6741c3f9865 (patch) | |
| tree | b9d46edc7076ef4dbca101e1f107c60f21a9680c | |
| parent | da1f50dc63d584d66caced7a30b6516e2b7aabe8 (diff) | |
Made a wrapper for a sqlite db, maybe can work?sqlite
| -rw-r--r-- | telegithook/bot/connections/__init__.py | 38 | ||||
| -rw-r--r-- | telegithook/bot/connections/driver.py | 67 |
2 files changed, 67 insertions, 38 deletions
diff --git a/telegithook/bot/connections/__init__.py b/telegithook/bot/connections/__init__.py deleted file mode 100644 index b6de48a..0000000 --- a/telegithook/bot/connections/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -import os -import json - -class Connections: - def __init__(self): - self.path = os.path.join(os.path.dirname(__file__), 'connections.json') - self.load_data() - - def load_data(self): - try: - with open(self.path) as f: - self.data = json.load(f) - except FileNotFoundError: - self.data = {} - self.serialize_data() - - def serialize_data(self): - with open(self.path, 'w') as f: - json.dump(self.data, f) - - def get(self, repo:str) -> list: - if repo in self.data: - return self.data[repo] - return [] - - def add(self, repo:str, chat_id:int): - if repo not in self.data: - self.data[repo] = [] - self.data[repo].append(chat_id) - self.serialize_data() - - def remove(self, repo:str) -> bool: - if repo in self.data and self.data.pop(repo): - return True - return False - - -CONNECTIONS = Connections() diff --git a/telegithook/bot/connections/driver.py b/telegithook/bot/connections/driver.py new file mode 100644 index 0000000..74e564f --- /dev/null +++ b/telegithook/bot/connections/driver.py @@ -0,0 +1,67 @@ +import os +import sqlite3 + +from typing import Union, List + +HERE = os.path.dirname(__file__) + +class Driver: + """A wrapper around a sqlite3 db with just 1 table holding + a one-to-many relation: username/repo -> chat_id""" + def __init__(self, name="conn.db"): + self.db = sqlite3.connect(os.path.join(HERE, name)) + tables = self.db.cursor().execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall() + if not any(res[0] == "connections" for res in tables): + self.db.cursor().execute("CREATE TABLE connections(repo VARCHAR(255), chat_id INT);") + + def query(self, base:str, *replaces, flat=False): + """Handy, query manually for iterations, commit if making edits!""" + res = self.db.cursor().execute(base, replaces).fetchall() + if flat: + return [ el for row in res for el in row ] + return res + + def all(self): + """Get all relations""" + return self.query("SELECT * FROM connections") + + def get(self, repo:str) -> List[int]: + """Get list of chat_ids associated to one repo""" + assert len(repo) < 255 + return self.query("SELECT chat_id FROM connections WHERE repo = ?", repo, flat=True) + + def get_repos(self, chat_id:int) -> List[int]: + """Get all repos associated to one chat_id""" + return self.query("SELECT repo FROM connections WHERE chat_id = ?", chat_id, flat=True) + + def add(self, repo:str, chat_id:int) -> bool: + """Add one repo:chat_id row. Will check if already present and return false if not added""" + assert len(repo) < 255 + if chat_id in self.get(repo): + return False + with self.db: + self.query("INSERT INTO connections VALUES (?, ?);", repo, chat_id) + self.db.commit() + return True + + def remove(self, repo:str, chat_id:int): + """Will remove one row repo:chat_id if present""" + assert len(repo) < 255 + with self.db: + self.query("DELETE FROM connections WHERE repo = ? AND chat_id = ?", repo, chat_id) + self.db.commit() + + def remove_all(self, repo:str): + """Will remove all rows associated to a repo if present""" + assert len(repo) < 255 + with self.db: + self.query("DELETE FROM connections WHERE repo = ?", repo) + self.db.commit() + + def remove_all_chats(self, chat_id:int): + """Will remove all rows associated to a chat_id if present""" + with self.db: + self.query("DELETE FROM connections WHERE chat_id = ?", repo) + self.db.commit() + +DB = Driver()
\ No newline at end of file |
