aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author alemi <[email protected]>2021-04-27 00:04:24 +0200
committer alemi <[email protected]>2021-04-27 00:04:24 +0200
commit4dc4a848a239eae379f32991d443f6741c3f9865 (patch)
treeb9d46edc7076ef4dbca101e1f107c60f21a9680c
parentda1f50dc63d584d66caced7a30b6516e2b7aabe8 (diff)
Made a wrapper for a sqlite db, maybe can work?sqlite
-rw-r--r--telegithook/bot/connections/__init__.py38
-rw-r--r--telegithook/bot/connections/driver.py67
2 files changed, 67 insertions, 38 deletions
diff --git a/telegithook/bot/connections/__init__.py b/telegithook/bot/connections/__init__.py
deleted file mode 100644
index b6de48a..0000000
--- a/telegithook/bot/connections/__init__.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import os
-import json
-
-class Connections:
- def __init__(self):
- self.path = os.path.join(os.path.dirname(__file__), 'connections.json')
- self.load_data()
-
- def load_data(self):
- try:
- with open(self.path) as f:
- self.data = json.load(f)
- except FileNotFoundError:
- self.data = {}
- self.serialize_data()
-
- def serialize_data(self):
- with open(self.path, 'w') as f:
- json.dump(self.data, f)
-
- def get(self, repo:str) -> list:
- if repo in self.data:
- return self.data[repo]
- return []
-
- def add(self, repo:str, chat_id:int):
- if repo not in self.data:
- self.data[repo] = []
- self.data[repo].append(chat_id)
- self.serialize_data()
-
- def remove(self, repo:str) -> bool:
- if repo in self.data and self.data.pop(repo):
- return True
- return False
-
-
-CONNECTIONS = Connections()
diff --git a/telegithook/bot/connections/driver.py b/telegithook/bot/connections/driver.py
new file mode 100644
index 0000000..74e564f
--- /dev/null
+++ b/telegithook/bot/connections/driver.py
@@ -0,0 +1,67 @@
+import os
+import sqlite3
+
+from typing import Union, List
+
+HERE = os.path.dirname(__file__)
+
+class Driver:
+ """A wrapper around a sqlite3 db with just 1 table holding
+ a one-to-many relation: username/repo -> chat_id"""
+ def __init__(self, name="conn.db"):
+ self.db = sqlite3.connect(os.path.join(HERE, name))
+ tables = self.db.cursor().execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
+ if not any(res[0] == "connections" for res in tables):
+ self.db.cursor().execute("CREATE TABLE connections(repo VARCHAR(255), chat_id INT);")
+
+ def query(self, base:str, *replaces, flat=False):
+ """Handy, query manually for iterations, commit if making edits!"""
+ res = self.db.cursor().execute(base, replaces).fetchall()
+ if flat:
+ return [ el for row in res for el in row ]
+ return res
+
+ def all(self):
+ """Get all relations"""
+ return self.query("SELECT * FROM connections")
+
+ def get(self, repo:str) -> List[int]:
+ """Get list of chat_ids associated to one repo"""
+ assert len(repo) < 255
+ return self.query("SELECT chat_id FROM connections WHERE repo = ?", repo, flat=True)
+
+ def get_repos(self, chat_id:int) -> List[int]:
+ """Get all repos associated to one chat_id"""
+ return self.query("SELECT repo FROM connections WHERE chat_id = ?", chat_id, flat=True)
+
+ def add(self, repo:str, chat_id:int) -> bool:
+ """Add one repo:chat_id row. Will check if already present and return false if not added"""
+ assert len(repo) < 255
+ if chat_id in self.get(repo):
+ return False
+ with self.db:
+ self.query("INSERT INTO connections VALUES (?, ?);", repo, chat_id)
+ self.db.commit()
+ return True
+
+ def remove(self, repo:str, chat_id:int):
+ """Will remove one row repo:chat_id if present"""
+ assert len(repo) < 255
+ with self.db:
+ self.query("DELETE FROM connections WHERE repo = ? AND chat_id = ?", repo, chat_id)
+ self.db.commit()
+
+ def remove_all(self, repo:str):
+ """Will remove all rows associated to a repo if present"""
+ assert len(repo) < 255
+ with self.db:
+ self.query("DELETE FROM connections WHERE repo = ?", repo)
+ self.db.commit()
+
+ def remove_all_chats(self, chat_id:int):
+ """Will remove all rows associated to a chat_id if present"""
+ with self.db:
+ self.query("DELETE FROM connections WHERE chat_id = ?", repo)
+ self.db.commit()
+
+DB = Driver() \ No newline at end of file