Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 56 additions & 9 deletions tasks/libs/common/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def get_common_ancestor(ctx, branch, base=None, try_fetch=True, hide=True) -> st
ctx: The invoke context.
branch: The branch to get the common ancestor with.
base: The base branch to get the common ancestor with. Defaults to the default branch.
try_fetch: Try to fetch the base branch if it's not found (to avoid S3 caching issues).
try_fetch: Fetch the base/branch refs before computing merge-base to avoid stale S3-cached refs.

Returns:
The common ancestor between two branches.
Expand All @@ -185,19 +185,15 @@ def get_common_ancestor(ctx, branch, base=None, try_fetch=True, hide=True) -> st
base = get_full_ref_name(base)
branch = get_full_ref_name(branch)

try:
return ctx.run(f"git merge-base {branch} {base}", hide=hide).stdout.strip()
except Exception:
if not try_fetch:
raise

# With S3 caching, it's possible that the base branch is not fetched
# With S3 caching, origin refs can be stale. Fetch them proactively to ensure
# we compute the merge-base against the latest remote state.
if try_fetch:
if base.startswith("origin/"):
ctx.run(f"git fetch origin {base.removeprefix('origin/')}", hide=hide)
if branch.startswith("origin/"):
ctx.run(f"git fetch origin {branch.removeprefix('origin/')}", hide=hide)

return ctx.run(f"git merge-base {branch} {base}", hide=hide).stdout.strip()
return ctx.run(f"git merge-base {branch} {base}", hide=hide).stdout.strip()


def check_uncommitted_changes(ctx):
Expand Down Expand Up @@ -231,6 +227,57 @@ def get_main_parent_commit(ctx) -> str:
return get_common_ancestor(ctx, "HEAD", f'origin/{get_default_branch()}')


def get_ancestor_base_branch(branch_name: str | None = None) -> str:
"""
Get the base branch to use for ancestor calculation.

This function tries to determine the correct base branch by:
1. Using COMPARE_TO_BRANCH environment variable if set (preferred in CI)
2. Falling back to GitHub API to look up the PR's target branch
3. Falling back to get_default_branch() if neither works

This is particularly important for PRs targeting release branches
(e.g., 7.54.x) where we need to find the ancestor from the release
branch, not main.

Args:
branch_name: The branch name to look up via GitHub API. If None, uses
CI_COMMIT_REF_NAME or falls back to the current branch.

Returns:
The base branch name to use for ancestor calculation.
"""
# First, check if COMPARE_TO_BRANCH is set (used in GitLab CI)
compare_to_branch = os.environ.get("COMPARE_TO_BRANCH")
if compare_to_branch:
print(f"Using COMPARE_TO_BRANCH environment variable: {compare_to_branch}")
return compare_to_branch

# Fall back to GitHub API to find the PR's target branch
from tasks.libs.ciproviders.github_api import GithubAPI

if branch_name is None:
branch_name = os.environ.get("CI_COMMIT_REF_NAME") or get_current_branch(Context())

try:
github = GithubAPI()
prs = list(github.get_pr_for_branch(branch_name))

if len(prs) == 0:
print(f"No PR found for branch {branch_name}, using default branch")
return get_default_branch()

if len(prs) > 1:
print(f"Warning: Multiple PRs found for branch {branch_name}, using first PR's base")

base_branch = prs[0].base.ref
print(f"Found PR #{prs[0].number} for branch {branch_name}, target branch: {base_branch}")
return base_branch
except Exception as e:
print(f"Warning: Failed to get PR base branch for {branch_name}: {e}")
return get_default_branch()


def check_base_branch(branch, release_version):
"""
Checks if the given branch is either the default branch or the release branch associated
Expand Down
7 changes: 5 additions & 2 deletions tasks/quality_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tasks.libs.common.color import color_message
from tasks.libs.common.git import (
create_tree,
get_ancestor_base_branch,
get_commit_sha,
get_common_ancestor,
get_current_branch,
Expand Down Expand Up @@ -507,9 +508,11 @@ def parse_and_trigger_gates(ctx, config_path: str = GATE_CONFIG_PATH) -> list[St

# Calculate relative sizes (delta from ancestor) before sending metrics
# This is done for all branches to include delta metrics in Datadog
ancestor = get_common_ancestor(ctx, "HEAD")
# Use get_ancestor_base_branch to correctly handle PRs targeting release branches
base_branch = get_ancestor_base_branch(branch)
ancestor = get_common_ancestor(ctx, "HEAD", base_branch)
current_commit = get_commit_sha(ctx)
# When on main branch, get_common_ancestor returns HEAD itself since merge-base of HEAD and origin/main
# When on main/release branch, get_common_ancestor returns HEAD itself since merge-base of HEAD and origin/<branch>
# is the current commit. In this case, use the parent commit as the ancestor instead.
if ancestor == current_commit:
ancestor = get_commit_sha(ctx, commit="HEAD~1")
Expand Down
2 changes: 1 addition & 1 deletion tasks/security_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def build_functional_tests(

arch = Arch.from_str(arch)
ldflags, gcflags, env = get_build_flags(ctx, static=static, arch=arch)
common_ancestor = get_common_ancestor(ctx, "HEAD")
common_ancestor = get_common_ancestor(ctx, "HEAD", try_fetch=False)
print(f"Using git ref {common_ancestor} as common ancestor between HEAD and main branch")
ldflags += f"-X {REPO_PATH}/{srcpath}.GitAncestorOnMain={common_ancestor} "

Expand Down
4 changes: 3 additions & 1 deletion tasks/static_quality_gates/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tasks.libs.common.color import color_message
from tasks.libs.common.constants import ORIGIN_CATEGORY, ORIGIN_PRODUCT, ORIGIN_SERVICE
from tasks.libs.common.datadog_api import create_gauge, send_metrics
from tasks.libs.common.git import is_a_release_branch
from tasks.libs.common.utils import get_metric_origin
from tasks.libs.package.size import InfraError, directory_size, extract_package, file_size

Expand Down Expand Up @@ -826,7 +827,8 @@ def generate_metric_reports(self, ctx, filename="static_gate_report.json", branc
json.dump(self.metrics, f)

CI_COMMIT_SHA = os.environ.get("CI_COMMIT_SHA")
if not is_nightly and branch == "main" and CI_COMMIT_SHA:
# Store reports for main and release branches to enable delta calculation for backport PRs
if not is_nightly and (branch == "main" or is_a_release_branch(ctx, branch)) and CI_COMMIT_SHA:
ctx.run(
f"aws s3 cp --only-show-errors --region us-east-1 --sse AES256 {filename} {self.S3_REPORT_PATH}/{CI_COMMIT_SHA}/{filename}",
hide="stdout",
Expand Down
124 changes: 123 additions & 1 deletion tasks/unit_tests/libs/common/git_tests.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import re
import unittest
from unittest.mock import MagicMock, call
from unittest.mock import MagicMock, call, patch

from invoke import MockContext, Result

from tasks.libs.common.git import (
check_local_branch,
check_uncommitted_changes,
get_ancestor_base_branch,
get_commit_sha,
get_current_branch,
get_full_ref_name,
Expand Down Expand Up @@ -287,3 +288,124 @@ def test_final_and_rc_tag_on_same_commit(self):

commit, _ = get_last_release_tag(c, "baubau", "7.61.*")
self.assertEqual(commit, "45f19a6a26c01dae9fdfce944d3fceae7f4e6498")


class TestGetAncestorBaseBranch(unittest.TestCase):
"""Tests for get_ancestor_base_branch function."""

@patch.dict('os.environ', {'COMPARE_TO_BRANCH': 'main'})
def test_uses_compare_to_branch_when_set(self):
"""Test that COMPARE_TO_BRANCH is used directly when set."""
result = get_ancestor_base_branch('feature/my-branch')

self.assertEqual(result, 'main')

@patch.dict('os.environ', {'COMPARE_TO_BRANCH': '7.54.x'})
def test_uses_compare_to_branch_release(self):
"""Test that COMPARE_TO_BRANCH works for release branches."""
result = get_ancestor_base_branch('feature/backport-fix')

self.assertEqual(result, '7.54.x')

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/my-branch'}, clear=True)
def test_falls_back_to_github_api_when_no_compare_to_branch(self, mock_github_api):
"""Test that GitHub API is used when COMPARE_TO_BRANCH is not set."""
mock_pr = MagicMock()
mock_pr.base.ref = 'main'
mock_pr.number = 12345

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/my-branch')

self.assertEqual(result, 'main')
mock_github_instance.get_pr_for_branch.assert_called_once_with('feature/my-branch')

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/release-fix'}, clear=True)
def test_pr_targeting_release_branch_via_api(self, mock_github_api):
"""Test that PRs targeting release branches return the release branch via API."""
mock_pr = MagicMock()
mock_pr.base.ref = '7.54.x'
mock_pr.number = 54321

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/release-fix')

self.assertEqual(result, '7.54.x')
mock_github_instance.get_pr_for_branch.assert_called_once_with('feature/release-fix')

@patch('tasks.libs.common.git.get_default_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/no-pr'}, clear=True)
def test_no_pr_found_falls_back_to_default(self, mock_github_api, mock_get_default_branch):
"""Test that when no PR is found, we fall back to default branch."""
mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = []
mock_github_api.return_value = mock_github_instance
mock_get_default_branch.return_value = 'main'

result = get_ancestor_base_branch('feature/no-pr')

self.assertEqual(result, 'main')
mock_get_default_branch.assert_called_once()

@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/multi-pr'}, clear=True)
def test_multiple_prs_uses_first(self, mock_github_api):
"""Test that when multiple PRs exist, we use the first one's base."""
mock_pr1 = MagicMock()
mock_pr1.base.ref = '7.55.x'
mock_pr1.number = 11111

mock_pr2 = MagicMock()
mock_pr2.base.ref = 'main'
mock_pr2.number = 22222

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr1, mock_pr2]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch('feature/multi-pr')

self.assertEqual(result, '7.55.x')

@patch('tasks.libs.common.git.get_default_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'feature/api-error'}, clear=True)
def test_github_api_error_falls_back_to_default(self, mock_github_api, mock_get_default_branch):
"""Test that GitHub API errors fall back to default branch."""
mock_github_api.side_effect = Exception("GitHub API unavailable")
mock_get_default_branch.return_value = 'main'

result = get_ancestor_base_branch('feature/api-error')

self.assertEqual(result, 'main')
mock_get_default_branch.assert_called_once()

@patch('tasks.libs.common.git.get_current_branch')
@patch('tasks.libs.ciproviders.github_api.GithubAPI')
@patch.dict('os.environ', {}, clear=True)
def test_uses_current_branch_when_no_env_var(self, mock_github_api, mock_get_current_branch):
"""Test that we use get_current_branch when CI_COMMIT_REF_NAME is not set."""
mock_get_current_branch.return_value = 'local-branch'

mock_pr = MagicMock()
mock_pr.base.ref = 'main'
mock_pr.number = 99999

mock_github_instance = MagicMock()
mock_github_instance.get_pr_for_branch.return_value = [mock_pr]
mock_github_api.return_value = mock_github_instance

result = get_ancestor_base_branch()

self.assertEqual(result, 'main')
mock_get_current_branch.assert_called_once()
mock_github_instance.get_pr_for_branch.assert_called_once_with('local-branch')
15 changes: 12 additions & 3 deletions tasks/unit_tests/package_lib_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,21 @@ def setUp(self) -> None:

@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'puppet'})
def test_found_on_dev(self):
c = MockContext(run={'git merge-base HEAD origin/main': Result('grand_ma')})
c = MockContext(
run={'git fetch origin main': Result(''), 'git merge-base HEAD origin/main': Result('grand_ma')}
)
self.assertEqual(get_ancestor(c, self.package_sizes, False), "grand_ma")

@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'puppet'})
def test_not_found_on_dev(self):
c = MockContext(run={'git merge-base HEAD origin/main': Result('grand_pa')})
c = MockContext(
run={'git fetch origin main': Result(''), 'git merge-base HEAD origin/main': Result('grand_pa')}
)
self.assertEqual(get_ancestor(c, self.package_sizes, False), "ma")

@patch.dict('os.environ', {'CI_COMMIT_REF_NAME': 'main'})
def test_on_main(self):
c = MockContext(run={'git merge-base HEAD origin/main': Result('kirk')})
c = MockContext(run={'git fetch origin main': Result(''), 'git merge-base HEAD origin/main': Result('kirk')})
self.assertEqual(get_ancestor(c, self.package_sizes, True), "kirk")


Expand Down Expand Up @@ -201,6 +205,7 @@ def test_on_main(self, mock_print):
s = PackageSize(arch, flavor, os_name, 2001)
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('12345'),
f"dpkg-deb --info {self.pkg_root}/{flavor}_7_{arch}.{os_name} | grep Installed-Size | cut -d : -f 2 | xargs": Result(
42
Expand All @@ -224,6 +229,7 @@ def test_on_branch_warning(self, mock_print):
s = PackageSize(arch, flavor, os_name, 70000000)
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"rpm -qip {self.pkg_root}/{flavor}-7.{arch}.rpm | grep Size | cut -d : -f 2 | xargs": Result(69000000),
}
Expand All @@ -244,6 +250,7 @@ def test_on_branch_ok_small_diff(self, mock_print):
s = PackageSize(arch, flavor, os_name, 70000000)
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"rpm -qip {self.pkg_root}/{flavor}-7.{arch}.rpm | grep Size | cut -d : -f 2 | xargs": Result(68004999),
}
Expand All @@ -263,6 +270,7 @@ def test_on_branch_ok_rpm(self, mock_print):
s = PackageSize(arch, flavor, os_name, 70000000)
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"rpm -qip {self.pkg_root}/{flavor}-7.{arch}.{os_name} | grep Size | cut -d : -f 2 | xargs": Result(
69000000
Expand All @@ -285,6 +293,7 @@ def test_on_branch_ko(self, mock_print):
s = PackageSize(arch, flavor, os_name, 70000000)
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"rpm -qip {self.pkg_root}/{flavor}-7.{arch}.rpm | grep Size | cut -d : -f 2 | xargs": Result(
139000000
Expand Down
3 changes: 3 additions & 0 deletions tasks/unit_tests/package_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_dev_branch_ko(self, upload_mock):
flavor = 'datadog-agent'
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"dpkg-deb --info {flavor} | grep Installed-Size | cut -d : -f 2 | xargs": Result(42),
f"rpm -qip {flavor} | grep Size | cut -d : -f 2 | xargs": Result(141000000),
Expand All @@ -45,6 +46,7 @@ def test_dev_branch_ok(self, upload_mock):
flavor = 'datadog-agent'
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"dpkg-deb --info {flavor} | grep Installed-Size | cut -d : -f 2 | xargs": Result(42),
f"rpm -qip {flavor} | grep Size | cut -d : -f 2 | xargs": Result(10500000),
Expand All @@ -66,6 +68,7 @@ def test_main_branch_ok(self):
flavor = 'datadog-agent'
c = MockContext(
run={
'git fetch origin main': Result(''),
'git merge-base HEAD origin/main': Result('25'),
f"dpkg-deb --info {flavor} | grep Installed-Size | cut -d : -f 2 | xargs": Result(42),
f"rpm -qip {flavor} | grep Size | cut -d : -f 2 | xargs": Result(20000000),
Expand Down
Loading