Skip to content

Commit

Permalink
Replace mariadb with PyMySQL
Browse files (browse the repository at this point in the history)
  • Loading branch information
vpchung committed Feb 20, 2025
1 parent 69b0b29 commit 8df96d1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 38 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,29 +2,31 @@
import sys
import logging

import mariadb
import pymysql
import pymysql.cursors
import pandas as pd


def connect_to_db(db: str = "challenge_service") -> mariadb.Connection:
def connect_to_db(db: str = "challenge_service") -> pymysql.Connection:
"""Establishes connection to the MariaDB database."""
credentials = {
"host": os.getenv("MARIADB_HOST"),
"port": int(os.getenv("MARIADB_PORT", 3306)),
"user": os.getenv("MARIADB_USER"),
"password": os.getenv("MARIADB_PASSWORD"),
"database": db,
"cursorclass": pymysql.cursors.DictCursor,
}
try:
conn = mariadb.connect(**credentials)
conn = pymysql.connect(**credentials)
logging.info(f"Connected to `{db}` database")
return conn
except mariadb.Error as err:
except pymysql.Error as err:
logging.error(f"Error connecting to the database: {err}")
sys.exit(1)


def get_table(conn: mariadb.Connection, table_name: str) -> pd.DataFrame:
def get_table(conn: pymysql.Connection, table_name: str) -> pd.DataFrame:
"""Returns all records from the specified table."""
query = f"SELECT * FROM {table_name}"
try:
Expand All @@ -33,12 +35,12 @@ def get_table(conn: mariadb.Connection, table_name: str) -> pd.DataFrame:
records = cursor.fetchall()
colnames = [val[0] for val in cursor.description]
return pd.DataFrame(records, columns=colnames)
except mariadb.Error as err:
except pymysql.Error as err:
logging.error(f"Error executing query: {err}")
return pd.DataFrame()


def truncate_table(conn: mariadb.Connection, table_name: str):
def truncate_table(conn: pymysql.Connection, table_name: str):
"""Deletes all rows from the specified table.
Temporarily disables foreign key checks for this operation.
Expand All @@ -50,12 +52,12 @@ def truncate_table(conn: mariadb.Connection, table_name: str):
cursor.execute(f"TRUNCATE TABLE {table_name}")
cursor.execute("SET FOREIGN_KEY_CHECKS = 1")
conn.commit() # Save changes made to table.
except mariadb.Error as err:
except pymysql.Error as err:
logging.error(f"Error truncating: {err}")
conn.rollback() # Revert any changes made to data.


def insert_data(conn: mariadb.Connection, table_name: str, data_df: pd.DataFrame):
def insert_data(conn: pymysql.Connection, table_name: str, data_df: pd.DataFrame):
"""Adds data to the specified table, one row at a time.
This iterative approach allows for logging invalid rows for later review.
Expand All @@ -64,12 +66,12 @@ def insert_data(conn: mariadb.Connection, table_name: str, data_df: pd.DataFrame
with conn.cursor() as cursor:
for _, row in data_df.iterrows():
colnames = ", ".join(row.index)
placeholders = ", ".join(["?"] * len(row))
placeholders = ", ".join(["%s"] * len(row))
query = f"INSERT INTO {table_name} ({colnames}) VALUES ({placeholders})"
try:
cursor.execute(query, tuple(row))
conn.commit()
except (mariadb.IntegrityError, mariadb.DataError) as err:
except (pymysql.IntegrityError, pymysql.DataError) as err:
id_colname = "id" if row.get("id") else "challenge_id"
id_value = row.get("id", row.get("challenge_id"))
logging.error(
Expand All @@ -78,12 +80,12 @@ def insert_data(conn: mariadb.Connection, table_name: str, data_df: pd.DataFrame
+ f" → Error: {err}"
)
conn.rollback()
except mariadb.Error as err:
except pymysql.Error as err:
logging.error(f"Error adding row to table `{table_name}`: {err}")
conn.rollback()


def update_table(conn: mariadb.Connection, table_name: str, data: pd.DataFrame):
def update_table(conn: pymysql.Connection, table_name: str, data: pd.DataFrame):
"""Updates the specified table."""
truncate_table(conn, table_name)
insert_data(conn, table_name, data)
2 changes: 1 addition & 1 deletion apps/openchallenges/data-lambda/pyproject.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"gspread==6.1.4",
"pandas==2.2.3",
"numpy==2.1.0",
"mariadb>=1.1.12",
"pymysql>=1.1.1",
]
name = "openchallenges-data-lambda"
version = "0.1.0"
Expand Down
35 changes: 11 additions & 24 deletions apps/openchallenges/data-lambda/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8df96d1

Please sign in to comment.