diff --git a/demo/db.py b/demo/db.py new file mode 100644 index 0000000..ecf28ba --- /dev/null +++ b/demo/db.py @@ -0,0 +1,104 @@ +import os +import sqlite3 +import json + +# DB Settings (Postgres) +DB_HOST = os.environ.get('DB_HOST', None) +DB_PORT = int(os.environ.get('DB_PORT', '5432')) +DB_NAME = os.environ.get('DB_NAME', 'gkachele') +DB_USER = os.environ.get('DB_USER', 'gkachele') +DB_PASSWORD = os.environ.get('DB_PASSWORD', 'gkachele_pass') + +# SQLite Settings +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +SQLITE_PATH = os.path.join(BASE_DIR, 'database', 'main.db') + +class CursorWrapper: + def __init__(self, cursor, is_sqlite=False): + self._cursor = cursor + self._is_sqlite = is_sqlite + + def execute(self, query, params=None): + try: + if self._is_sqlite: + # Skip postgres-specific maintenance commands + if 'setval' in query or 'pg_get_serial_sequence' in query: + return None + + # Adapt psycopg2-style placeholders (%s) to sqlite (?) + if isinstance(query, str) and '%s' in query: + query = query.replace('%s', '?') + + # SQLite doesn't support IF NOT EXISTS in ALTER TABLE + if 'ALTER TABLE' in query and 'ADD COLUMN' in query and 'IF NOT EXISTS' in query: + query = query.replace('IF NOT EXISTS', '') + + # Handle SERIAL PRIMARY KEY and other PG types + query = query.replace('SERIAL PRIMARY KEY', 'INTEGER PRIMARY KEY AUTOINCREMENT') + if 'IF NOT EXISTS' not in query and 'CREATE TABLE' in query: + query = query.replace('CREATE TABLE', 'CREATE TABLE IF NOT EXISTS') + else: + # Adapt sqlite-style placeholders (?) to psycopg2 (%s) + if isinstance(query, str) and '?' in query: + query = query.replace('?', '%s') + + if params is None: + return self._cursor.execute(query) + return self._cursor.execute(query, params) + except (sqlite3.OperationalError, sqlite3.IntegrityError) as e: + msg = str(e).lower() + if "duplicate column name" in msg or "already exists" in msg: + return None + raise e + + def fetchone(self): + return self._cursor.fetchone() + + def fetchall(self): + return self._cursor.fetchall() + + def __getattr__(self, name): + return getattr(self._cursor, name) + +class ConnectionWrapper: + def __init__(self, conn, is_sqlite=False): + self._conn = conn + self._is_sqlite = is_sqlite + + def cursor(self): + return CursorWrapper(self._conn.cursor(), is_sqlite=self._is_sqlite) + + def commit(self): + return self._conn.commit() + + def close(self): + return self._conn.close() + + def __getattr__(self, name): + return getattr(self._conn, name) + +def get_db(): + if DB_HOST: + try: + import psycopg2 + conn = psycopg2.connect( + host=DB_HOST, + port=DB_PORT, + dbname=DB_NAME, + user=DB_USER, + password=DB_PASSWORD, + ) + return ConnectionWrapper(conn, is_sqlite=False) + except Exception as e: + print(f"Postgres connection failed: {e}. Falling back to SQLite.") + + # Fallback to SQLite + os.makedirs(os.path.dirname(SQLITE_PATH), exist_ok=True) + conn = sqlite3.connect(SQLITE_PATH) + return ConnectionWrapper(conn, is_sqlite=True) + +try: + import psycopg2 + IntegrityError = psycopg2.IntegrityError +except ImportError: + IntegrityError = sqlite3.IntegrityError