Skip to content

Commit 46ee27a

Browse files
[DPE-4179] Add integration test for upgrades and rollbacks (#244)
## Issue We are missing integration tests for upgrades and rollbacks ## Solution Add tests
1 parent 2d1bbd7 commit 46ee27a

File tree

10 files changed

+558
-34
lines changed

10 files changed

+558
-34
lines changed

actions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ resume-upgrade:
1616
- force incompatible upgrade and/or
1717
- continue upgrade if 1+ upgraded units have non-active status
1818
required: []
19+
1920
set-tls-private-key:
2021
description:
2122
Set the private key, which will be used for certificate signing requests (CSR). Run

tests/integration/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2024 Canonical Ltd.
2+
# See LICENSE file for licensing details.
3+
4+
import logging
5+
6+
import pytest
7+
from pytest_operator.plugin import OpsTest
8+
9+
from . import juju_
10+
from .helpers import APPLICATION_DEFAULT_APP_NAME, get_application_name
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
@pytest.fixture
16+
async def continuous_writes(ops_test: OpsTest):
17+
"""Starts continuous writes to the MySQL cluster for a test and clear the writes at the end."""
18+
application_name = get_application_name(ops_test, APPLICATION_DEFAULT_APP_NAME)
19+
20+
application_unit = ops_test.model.applications[application_name].units[0]
21+
22+
logger.info("Clearing continuous writes")
23+
await juju_.run_action(application_unit, "clear-continuous-writes")
24+
25+
logger.info("Starting continuous writes")
26+
await juju_.run_action(application_unit, "start-continuous-writes")
27+
28+
yield
29+
30+
logger.info("Clearing continuous writes")
31+
await juju_.run_action(application_unit, "clear-continuous-writes")

tests/integration/helpers.py

Lines changed: 223 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33

44
import itertools
55
import json
6+
import logging
67
import subprocess
78
import tempfile
89
from typing import Dict, List, Optional
910

1011
import mysql.connector
12+
import tenacity
1113
import yaml
14+
from juju.model import Model
1215
from juju.unit import Unit
1316
from mysql.connector.errors import (
1417
DatabaseError,
@@ -17,16 +20,25 @@
1720
ProgrammingError,
1821
)
1922
from pytest_operator.plugin import OpsTest
20-
from tenacity import Retrying, retry, stop_after_attempt, wait_fixed
2123

2224
from .connector import MySQLConnector
25+
from .juju_ import run_action
26+
27+
logger = logging.getLogger(__name__)
28+
29+
CONTINUOUS_WRITES_DATABASE_NAME = "continuous_writes_database"
30+
CONTINUOUS_WRITES_TABLE_NAME = "data"
31+
32+
MYSQL_DEFAULT_APP_NAME = "mysql-k8s"
33+
MYSQL_ROUTER_DEFAULT_APP_NAME = "mysql-router-k8s"
34+
APPLICATION_DEFAULT_APP_NAME = "mysql-test-app"
2335

2436
SERVER_CONFIG_USERNAME = "serverconfig"
2537
CONTAINER_NAME = "mysql-router"
2638
LOGROTATE_EXECUTOR_SERVICE = "logrotate_executor"
2739

2840

29-
async def execute_queries_on_unit(
41+
async def execute_queries_against_unit(
3042
unit_address: str,
3143
username: str,
3244
password: str,
@@ -75,13 +87,10 @@ async def get_server_config_credentials(unit: Unit) -> Dict:
7587
Returns:
7688
A dictionary with the server config username and password
7789
"""
78-
action = await unit.run_action(action_name="get-password", username=SERVER_CONFIG_USERNAME)
79-
result = await action.wait()
90+
return await run_action(unit, "get-password", username=SERVER_CONFIG_USERNAME)
8091

81-
return result.results
8292

83-
84-
async def get_inserted_data_by_application(unit: Unit) -> str:
93+
async def get_inserted_data_by_application(unit: Unit) -> Optional[str]:
8594
"""Helper to run an action to retrieve inserted data by the application.
8695
8796
Args:
@@ -90,10 +99,7 @@ async def get_inserted_data_by_application(unit: Unit) -> str:
9099
Returns:
91100
A string representing the inserted data
92101
"""
93-
action = await unit.run_action("get-inserted-data")
94-
result = await action.wait()
95-
96-
return result.results.get("data")
102+
return (await run_action(unit, "get-inserted-data")).get("data")
97103

98104

99105
async def get_unit_address(ops_test: OpsTest, unit_name: str) -> str:
@@ -326,7 +332,9 @@ async def stop_running_flush_mysqlrouter_job(ops_test: OpsTest, unit_name: str)
326332
)
327333

328334
# hold execution until process is stopped
329-
for attempt in Retrying(reraise=True, stop=stop_after_attempt(45), wait=wait_fixed(2)):
335+
for attempt in tenacity.Retrying(
336+
reraise=True, stop=tenacity.stop_after_attempt(45), wait=tenacity.wait_fixed(2)
337+
):
330338
with attempt:
331339
if await get_process_pid(ops_test, unit_name, CONTAINER_NAME, "logrotate"):
332340
raise Exception("Failed to stop the flush_mysql_logs logrotate process.")
@@ -362,7 +370,7 @@ async def rotate_mysqlrouter_logs(ops_test: OpsTest, unit_name: str) -> None:
362370
)
363371

364372

365-
@retry(stop=stop_after_attempt(8), wait=wait_fixed(15), reraise=True)
373+
@tenacity.retry(stop=tenacity.stop_after_attempt(8), wait=tenacity.wait_fixed(15), reraise=True)
366374
def is_connection_possible(credentials: Dict, **extra_opts) -> bool:
367375
"""Test a connection to a MySQL server.
368376
@@ -431,3 +439,205 @@ async def get_tls_certificate_issuer(
431439
return_code, issuer, _ = await ops_test.juju(*get_tls_certificate_issuer_commands)
432440
assert return_code == 0, f"failed to get TLS certificate issuer on {unit_name=}"
433441
return issuer
442+
443+
444+
def get_application_name(ops_test: OpsTest, application_name_substring: str) -> str:
445+
"""Returns the name of the application with the provided application name.
446+
447+
This enables us to retrieve the name of the deployed application in an existing model.
448+
449+
Note: if multiple applications with the application name exist,
450+
the first one found will be returned.
451+
"""
452+
for application in ops_test.model.applications:
453+
if application_name_substring == application:
454+
return application
455+
456+
return ""
457+
458+
459+
@tenacity.retry(stop=tenacity.stop_after_attempt(30), wait=tenacity.wait_fixed(5), reraise=True)
460+
async def get_primary_unit(
461+
ops_test: OpsTest,
462+
unit: Unit,
463+
app_name: str,
464+
) -> Unit:
465+
"""Helper to retrieve the primary unit.
466+
467+
Args:
468+
ops_test: The ops test object passed into every test case
469+
unit: A unit on which to run dba.get_cluster().status() on
470+
app_name: The name of the test application
471+
cluster_name: The name of the test cluster
472+
473+
Returns:
474+
A juju unit that is a MySQL primary
475+
"""
476+
units = ops_test.model.applications[app_name].units
477+
results = await run_action(unit, "get-cluster-status")
478+
479+
primary_unit = None
480+
for k, v in results["status"]["defaultreplicaset"]["topology"].items():
481+
if v["memberrole"] == "primary":
482+
unit_name = f"{app_name}/{k.split('-')[-1]}"
483+
primary_unit = [unit for unit in units if unit.name == unit_name][0]
484+
break
485+
486+
if not primary_unit:
487+
raise ValueError("Unable to find primary unit")
488+
return primary_unit
489+
490+
491+
async def get_primary_unit_wrapper(ops_test: OpsTest, app_name: str, unit_excluded=None) -> Unit:
492+
"""Wrapper for getting primary.
493+
494+
Args:
495+
ops_test: The ops test object passed into every test case
496+
app_name: The name of the application
497+
unit_excluded: excluded unit to run command on
498+
Returns:
499+
The primary Unit object
500+
"""
501+
logger.info("Retrieving primary unit")
502+
units = ops_test.model.applications[app_name].units
503+
if unit_excluded:
504+
# if defined, exclude unit from available unit to run command on
505+
# useful when the workload is stopped on unit
506+
unit = ({unit for unit in units if unit.name != unit_excluded.name}).pop()
507+
else:
508+
unit = units[0]
509+
510+
primary_unit = await get_primary_unit(ops_test, unit, app_name)
511+
512+
return primary_unit
513+
514+
515+
async def get_max_written_value_in_database(
516+
ops_test: OpsTest, unit: Unit, credentials: dict
517+
) -> int:
518+
"""Retrieve the max written value in the MySQL database.
519+
520+
Args:
521+
ops_test: The ops test framework
522+
unit: The MySQL unit on which to execute queries on
523+
credentials: Database credentials to use
524+
"""
525+
unit_address = await get_unit_address(ops_test, unit.name)
526+
527+
select_max_written_value_sql = [
528+
f"SELECT MAX(number) FROM `{CONTINUOUS_WRITES_DATABASE_NAME}`.`{CONTINUOUS_WRITES_TABLE_NAME}`;"
529+
]
530+
531+
output = await execute_queries_against_unit(
532+
unit_address,
533+
credentials["username"],
534+
credentials["password"],
535+
select_max_written_value_sql,
536+
)
537+
538+
return output[0]
539+
540+
541+
async def ensure_all_units_continuous_writes_incrementing(
542+
ops_test: OpsTest, mysql_units: Optional[List[Unit]] = None
543+
) -> None:
544+
"""Ensure that continuous writes is incrementing on all units.
545+
546+
Also, ensure that all continuous writes up to the max written value is available
547+
on all units (ensure that no committed data is lost).
548+
"""
549+
logger.info("Ensure continuous writes are incrementing")
550+
551+
mysql_application_name = get_application_name(ops_test, MYSQL_DEFAULT_APP_NAME)
552+
553+
if not mysql_units:
554+
mysql_units = ops_test.model.applications[mysql_application_name].units
555+
556+
primary = await get_primary_unit_wrapper(ops_test, mysql_application_name)
557+
558+
server_config_credentials = await get_server_config_credentials(mysql_units[0])
559+
560+
last_max_written_value = await get_max_written_value_in_database(
561+
ops_test, primary, server_config_credentials
562+
)
563+
564+
select_all_continuous_writes_sql = [
565+
f"SELECT * FROM `{CONTINUOUS_WRITES_DATABASE_NAME}`.`{CONTINUOUS_WRITES_TABLE_NAME}`"
566+
]
567+
568+
async with ops_test.fast_forward():
569+
for unit in mysql_units:
570+
for attempt in tenacity.Retrying(
571+
reraise=True, stop=tenacity.stop_after_delay(5 * 60), wait=tenacity.wait_fixed(10)
572+
):
573+
with attempt:
574+
# ensure that all units are up to date (including the previous primary)
575+
unit_address = await get_unit_address(ops_test, unit.name)
576+
577+
# ensure the max written value is incrementing (continuous writes is active)
578+
max_written_value = await get_max_written_value_in_database(
579+
ops_test, unit, server_config_credentials
580+
)
581+
assert (
582+
max_written_value > last_max_written_value
583+
), "Continuous writes not incrementing"
584+
585+
# ensure that the unit contains all values up to the max written value
586+
all_written_values = set(
587+
await execute_queries_against_unit(
588+
unit_address,
589+
server_config_credentials["username"],
590+
server_config_credentials["password"],
591+
select_all_continuous_writes_sql,
592+
)
593+
)
594+
numbers = {n for n in range(1, max_written_value)}
595+
assert (
596+
numbers <= all_written_values
597+
), f"Missing numbers in database for unit {unit.name}"
598+
599+
last_max_written_value = max_written_value
600+
601+
602+
async def get_workload_version(ops_test: OpsTest, unit_name: str) -> str:
603+
"""Get the workload version of the deployed router charm."""
604+
return_code, output, _ = await ops_test.juju(
605+
"ssh",
606+
unit_name,
607+
"sudo",
608+
"cat",
609+
f"/var/lib/juju/agents/unit-{unit_name.replace('/', '-')}/charm/workload_version",
610+
)
611+
612+
assert return_code == 0
613+
return output.strip()
614+
615+
616+
async def get_leader_unit(
617+
ops_test: Optional[OpsTest], app_name: str, model: Optional[Model] = None
618+
) -> Optional[Unit]:
619+
"""Get the leader unit of a given application.
620+
621+
Args:
622+
ops_test: The ops test framework instance
623+
app_name: The name of the application
624+
model: The model to use (overrides ops_test.model)
625+
"""
626+
leader_unit = None
627+
if not model:
628+
model = ops_test.model
629+
for unit in model.applications[app_name].units:
630+
if await unit.is_leader_from_status():
631+
leader_unit = unit
632+
break
633+
634+
return leader_unit
635+
636+
637+
def get_juju_status(model_name: str) -> str:
638+
"""Return the juju status output.
639+
640+
Args:
641+
model_name: The model for which to retrieve juju status for
642+
"""
643+
return subprocess.check_output(["juju", "status", "--model", model_name]).decode("utf-8")

tests/integration/juju_.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import importlib.metadata
55

6+
import juju.unit
7+
68
# libjuju version != juju agent version, but the major version should be identical—which is good
79
_libjuju_version = importlib.metadata.version("juju")
810
is_3_1_or_higher = (
@@ -12,3 +14,14 @@
1214
)
1315

1416
is_3_or_higher = int(_libjuju_version.split(".")[0]) >= 3
17+
18+
19+
async def run_action(unit: juju.unit.Unit, action_name, **params):
20+
action = await unit.run_action(action_name=action_name, **params)
21+
result = await action.wait()
22+
# Syntax changed across libjuju major versions
23+
if is_3_or_higher:
24+
assert result.results.get("return-code") == 0
25+
else:
26+
assert result.results.get("Code") == "0"
27+
return result.results

tests/integration/test_charm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from pytest_operator.plugin import OpsTest
1313

1414
from .helpers import (
15-
execute_queries_on_unit,
15+
APPLICATION_DEFAULT_APP_NAME,
16+
MYSQL_DEFAULT_APP_NAME,
17+
MYSQL_ROUTER_DEFAULT_APP_NAME,
18+
execute_queries_against_unit,
1619
get_inserted_data_by_application,
1720
get_server_config_credentials,
1821
get_unit_address,
@@ -23,9 +26,9 @@
2326

2427
METADATA = yaml.safe_load(Path("./metadata.yaml").read_text())
2528

26-
MYSQL_APP_NAME = "mysql-k8s"
27-
MYSQL_ROUTER_APP_NAME = "mysql-router-k8s"
28-
APPLICATION_APP_NAME = "mysql-test-app"
29+
MYSQL_APP_NAME = MYSQL_DEFAULT_APP_NAME
30+
MYSQL_ROUTER_APP_NAME = MYSQL_ROUTER_DEFAULT_APP_NAME
31+
APPLICATION_APP_NAME = APPLICATION_DEFAULT_APP_NAME
2932
SLOW_TIMEOUT = 15 * 60
3033
MODEL_CONFIG = {"logging-config": "<root>=INFO;unit=DEBUG"}
3134

@@ -111,7 +114,7 @@ async def test_database_relation(ops_test: OpsTest):
111114
select_inserted_data_sql = [
112115
f"SELECT data FROM continuous_writes_database.random_data WHERE data = '{inserted_data}'",
113116
]
114-
selected_data = await execute_queries_on_unit(
117+
selected_data = await execute_queries_against_unit(
115118
mysql_unit_address,
116119
server_config_credentials["username"],
117120
server_config_credentials["password"],

0 commit comments

Comments
 (0)