Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tasks/libs/common/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,57 @@ def get_main_parent_commit(ctx) -> str:
return get_common_ancestor(ctx, "HEAD", f'origin/{get_default_branch()}')


def get_ancestor_base_branch(branch_name: str | None = None) -> str:
"""
Get the base branch to use for ancestor calculation.

This function tries to determine the correct base branch by:
1. Using COMPARE_TO_BRANCH environment variable if set (preferred in CI)
2. Falling back to GitHub API to look up the PR's target branch
3. Falling back to get_default_branch() if neither works

This is particularly important for PRs targeting release branches
(e.g., 7.54.x) where we need to find the ancestor from the release
branch, not main.

Args:
branch_name: The branch name to look up via GitHub API. If None, uses
CI_COMMIT_REF_NAME or falls back to the current branch.

Returns:
The base branch name to use for ancestor calculation.
"""
# First, check if COMPARE_TO_BRANCH is set (used in GitLab CI)
compare_to_branch = os.environ.get("COMPARE_TO_BRANCH")
if compare_to_branch:
print(f"Using COMPARE_TO_BRANCH environment variable: {compare_to_branch}")
return compare_to_branch

# Fall back to GitHub API to find the PR's target branch
from tasks.libs.ciproviders.github_api import GithubAPI

if branch_name is None:
branch_name = os.environ.get("CI_COMMIT_REF_NAME") or get_current_branch(Context())

try:
github = GithubAPI()
prs = list(github.get_pr_for_branch(branch_name))

if len(prs) == 0:
print(f"No PR found for branch {branch_name}, using default branch")
return get_default_branch()

if len(prs) > 1:
print(f"Warning: Multiple PRs found for branch {branch_name}, using first PR's base")

base_branch = prs[0].base.ref
print(f"Found PR #{prs[0].number} for branch {branch_name}, target branch: {base_branch}")
return base_branch
except Exception as e:
print(f"Warning: Failed to get PR base branch for {branch_name}: {e}")
return get_default_branch()


def check_base_branch(branch, release_version):
"""
Checks if the given branch is either the default branch or the release branch associated
Expand Down
10 changes: 8 additions & 2 deletions tasks/quality_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tasks.libs.common.color import color_message
from tasks.libs.common.git import (
create_tree,
get_ancestor_base_branch,
get_commit_sha,
get_common_ancestor,
get_current_branch,
Expand Down Expand Up @@ -507,9 +508,14 @@ def parse_and_trigger_gates(ctx, config_path: str = GATE_CONFIG_PATH) -> list[St

# Calculate relative sizes (delta from ancestor) before sending metrics
# This is done for all branches to include delta metrics in Datadog
ancestor = get_common_ancestor(ctx, "HEAD")
# Use get_ancestor_base_branch to correctly handle PRs targeting release branches
base_branch = get_ancestor_base_branch(branch)
# get_common_ancestor is supposed to fetch this but it doesn't, so we do it here explicitly
ctx.run(f"git fetch origin {branch.removeprefix('origin/')}", hide=True)
ctx.run(f"git fetch origin {base_branch.removeprefix('origin/')}", hide=True)
ancestor = get_common_ancestor(ctx, "HEAD", base_branch)
current_commit = get_commit_sha(ctx)
# When on main branch, get_common_ancestor returns HEAD itself since merge-base of HEAD and origin/main
# When on main/release branch, get_common_ancestor returns HEAD itself since merge-base of HEAD and origin/<branch>
# is the current commit. In this case, use the parent commit as the ancestor instead.
if ancestor == current_commit:
ancestor = get_commit_sha(ctx, commit="HEAD~1")
Expand Down
4 changes: 3 additions & 1 deletion tasks/static_quality_gates/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tasks.libs.common.color import color_message
from tasks.libs.common.constants import ORIGIN_CATEGORY, ORIGIN_PRODUCT, ORIGIN_SERVICE
from tasks.libs.common.datadog_api import create_gauge, send_metrics
from tasks.libs.common.git import is_a_release_branch
from tasks.libs.common.utils import get_metric_origin
from tasks.libs.package.size import InfraError, directory_size, extract_package, file_size

Expand Down Expand Up @@ -826,7 +827,8 @@ def generate_metric_reports(self, ctx, filename="static_gate_report.json", branc
json.dump(self.metrics, f)

CI_COMMIT_SHA = os.environ.get("CI_COMMIT_SHA")
if not is_nightly and branch == "main" and CI_COMMIT_SHA:
# Store reports for main and release branches to enable delta calculation for backport PRs
if not is_nightly and (branch == "main" or is_a_release_branch(ctx, branch)) and CI_COMMIT_SHA:
ctx.run(
f"aws s3 cp --only-show-errors --region us-east-1 --sse AES256 {filename} {self.S3_REPORT_PATH}/{CI_COMMIT_SHA}/{filename}",
hide="stdout",
Expand Down
124 changes: 123 additions & 1 deletion tasks/unit_tests/libs/common/git_tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import re
import unittest
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch

from invoke import MockContext, Result

from tasks.libs.common.git import (
check_local_branch,
check_uncommitted_changes,
get_ancestor_base_branch,
get_commit_sha,
get_current_branch,
get_full_ref_name,
Expand Down Expand Up @@ -287,3 +288,124 @@ def test_final_and_rc_tag_on_same_commit(self):

commit, _ = get_last_release_tag(c, "baubau", "7.61.*")
self.assertEqual(commit, "45f19a6a26c01dae9fdfce944d3fceae7f4e6498")


class TestGetAncestorBaseBranch(unittest.TestCase):
"""Tests for get_ancestor_base_branch function."""

@patch.dict('os.environ', {'COMPARE_TO_BRANCH': 'main'})
def test_uses_compare_to_branch_when_set(self):
"""Test that COMPARE_TO_BRANCH is used directly when set."""
result = get_ancestor_base_branch('feature/my-branch')

self.assertEqual(result, 'main')

@patch.dict('os.environ', {'COMPARE_TO_BRANCH': '7.54.x'})
def test_uses_compare_to_branch_release(self):
"""Test that COMPARE_TO_BRANCH works for release branches."""
result = get_ancestor_base_branch('feature/backport-fix')

self.assertEqual(result, '7.54.x')

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/my-branch'}, clear=True)
def test_falls_back_to_github_api_when_no_compare_to_branch(self, mock_github_api):
"""Test that GitHub API is used when COMPARE_TO_BRANCH is not set."""
mock_pr = MagicMock()
mock_pr.base.ref = 'main'
mock_pr.number = 12345

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/my-branch')

self.assertEqual(result, 'main')
mock_github_instance.get_pr_for_branch.assert_called_once_with('feature/my-branch')

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/release-fix'}, clear=True)
def test_pr_targeting_release_branch_via_api(self, mock_github_api):
"""Test that PRs targeting release branches return the release branch via API."""
mock_pr = MagicMock()
mock_pr.base.ref = '7.54.x'
mock_pr.number = 54321

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/release-fix')

self.assertEqual(result, '7.54.x')
mock_github_instance.get_pr_for_branch.assert_called_once_with('feature/release-fix')

@patch('tasks.libs.common.git.get_default_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/no-pr'}, clear=True)
def test_no_pr_found_falls_back_to_default(self, mock_github_api, mock_get_default_branch):
"""Test that when no PR is found, we fall back to default branch."""
mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = []
mock_github_api.return_value = mock_github_instance
mock_get_default_branch.return_value = 'main'

result = get_ancestor_base_branch('feature/no-pr')

self.assertEqual(result, 'main')
mock_get_default_branch.assert_called_once()

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/multi-pr'}, clear=True)
def test_multiple_prs_uses_first(self, mock_github_api):
"""Test that when multiple PRs exist, we use the first one's base."""
mock_pr1 = MagicMock()
mock_pr1.base.ref = '7.55.x'
mock_pr1.number = 11111

mock_pr2 = MagicMock()
mock_pr2.base.ref = 'main'
mock_pr2.number = 22222

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr1, mock_pr2]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/multi-pr')

self.assertEqual(result, '7.55.x')

@patch('tasks.libs.common.git.get_default_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/api-error'}, clear=True)
def test_github_api_error_falls_back_to_default(self, mock_github_api, mock_get_default_branch):
"""Test that GitHub API errors fall back to default branch."""
mock_github_api.side_effect = Exception("GitHub API unavailable")
mock_get_default_branch.return_value = 'main'

result = get_ancestor_base_branch('feature/api-error')

self.assertEqual(result, 'main')
mock_get_default_branch.assert_called_once()

@patch('tasks.libs.common.git.get_current_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {}, clear=True)
def test_uses_current_branch_when_no_env_var(self, mock_github_api, mock_get_current_branch):
"""Test that we use get_current_branch when CI_COMMIT_REF_NAME is not set."""
mock_get_current_branch.return_value = 'local-branch'

mock_pr = MagicMock()
mock_pr.base.ref = 'main'
mock_pr.number = 99999

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch()

self.assertEqual(result, 'main')
mock_get_current_branch.assert_called_once()
mock_github_instance.get_pr_for_branch.assert_called_once_with('local-branch')
115 changes: 115 additions & 0 deletions tasks/unit_tests/static_quality_gates_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

import os
import tempfile
import unittest
from unittest.mock import MagicMock, mock_open, patch

Expand Down Expand Up @@ -1738,5 +1739,119 @@ def test_extracts_revert_pr_not_original(self):
self.assertEqual(result, "44639")


class TestGenerateMetricReports(unittest.TestCase):
"""Test the generate_metric_reports function for S3 upload behavior."""

def setUp(self):
"""Create a temporary directory for test files."""
self.temp_dir = tempfile.mkdtemp()
self.temp_report_file = os.path.join(self.temp_dir, "static_gate_report.json")

def tearDown(self):
"""Clean up temporary files."""
if os.path.exists(self.temp_report_file):
os.remove(self.temp_report_file)
if os.path.exists(self.temp_dir):
os.rmdir(self.temp_dir)

@patch.dict('os.environ', {'CI_COMMIT_SHA': 'abc123def456'})
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_uploads_to_s3_for_main_branch(self, mock_is_release):
"""Should upload report to S3 when on main branch."""
mock_is_release.return_value = False
handler = GateMetricHandler("main", "dev")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(
run={
f'aws s3 cp --only-show-errors --region us-east-1 --sse AES256 {self.temp_report_file} s3://dd-ci-artefacts-build-stable/datadog-agent/static_quality_gates/abc123def456/{self.temp_report_file}': Result(
"Done"
),
}
)

handler.generate_metric_reports(ctx, filename=self.temp_report_file, branch="main", is_nightly=False)

# Verify S3 upload was called
self.assertEqual(len(ctx.run.call_args_list), 1)

@patch.dict('os.environ', {'CI_COMMIT_SHA': 'abc123def456'})
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_uploads_to_s3_for_release_branch(self, mock_is_release):
"""Should upload report to S3 when on a release branch (e.g., 7.54.x)."""
mock_is_release.return_value = True
handler = GateMetricHandler("7.54.x", "dev")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(
run={
f'aws s3 cp --only-show-errors --region us-east-1 --sse AES256 {self.temp_report_file} s3://dd-ci-artefacts-build-stable/datadog-agent/static_quality_gates/abc123def456/{self.temp_report_file}': Result(
"Done"
),
}
)

handler.generate_metric_reports(ctx, filename=self.temp_report_file, branch="7.54.x", is_nightly=False)

# Verify S3 upload was called
self.assertEqual(len(ctx.run.call_args_list), 1)

@patch.dict('os.environ', {'CI_COMMIT_SHA': 'abc123def456'})
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_no_upload_for_feature_branch(self, mock_is_release):
"""Should NOT upload report to S3 when on a feature branch."""
mock_is_release.return_value = False
handler = GateMetricHandler("feature/my-branch", "dev")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(run={})

handler.generate_metric_reports(
ctx, filename=self.temp_report_file, branch="feature/my-branch", is_nightly=False
)

# Verify S3 upload was NOT called
self.assertEqual(len(ctx.run.call_args_list), 0)

@patch.dict('os.environ', {'CI_COMMIT_SHA': 'abc123def456'})
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_no_upload_for_nightly_main(self, mock_is_release):
"""Should NOT upload report to S3 for nightly builds even on main."""
mock_is_release.return_value = False
handler = GateMetricHandler("main", "nightly")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(run={})

handler.generate_metric_reports(ctx, filename=self.temp_report_file, branch="main", is_nightly=True)

# Verify S3 upload was NOT called
self.assertEqual(len(ctx.run.call_args_list), 0)

@patch.dict('os.environ', {'CI_COMMIT_SHA': 'abc123def456'})
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_no_upload_for_nightly_release(self, mock_is_release):
"""Should NOT upload report to S3 for nightly builds on release branches."""
mock_is_release.return_value = True
handler = GateMetricHandler("7.54.x", "nightly")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(run={})

handler.generate_metric_reports(ctx, filename=self.temp_report_file, branch="7.54.x", is_nightly=True)

# Verify S3 upload was NOT called
self.assertEqual(len(ctx.run.call_args_list), 0)

@patch.dict('os.environ', {}, clear=True)
@patch('tasks.static_quality_gates.gates.is_a_release_branch')
def test_no_upload_without_commit_sha(self, mock_is_release):
"""Should NOT upload report to S3 when CI_COMMIT_SHA is not set."""
mock_is_release.return_value = False
handler = GateMetricHandler("main", "dev")
handler.metrics = {"test_gate": {"current_on_disk_size": 100}}
ctx = MockContext(run={})

handler.generate_metric_reports(ctx, filename=self.temp_report_file, branch="main", is_nightly=False)

# Verify S3 upload was NOT called
self.assertEqual(len(ctx.run.call_args_list), 0)


if __name__ == '__main__':
unittest.main()