diff --git a/Cargo.toml b/Cargo.toml index 74b74d3..1b6a7ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,9 +54,10 @@ chrono = "0.4" itertools = "0.13.0" +derive_builder = "0.20.0" + [dev-dependencies] insta = "1.26" -derive_builder = "0.20.0" wiremock = "0.6.0" base64 = "0.22.1" tracing-test = "0.2.4" diff --git a/src/bin/bors.rs b/src/bin/bors.rs index 2863767..0ecc8f3 100644 --- a/src/bin/bors.rs +++ b/src/bin/bors.rs @@ -6,8 +6,9 @@ use std::time::Duration; use anyhow::Context; use bors::{ - create_app, create_bors_process, create_github_client, load_repositories, BorsContext, - BorsGlobalEvent, CommandParser, PgDbClient, ServerState, TeamApiClient, WebhookSecret, + create_app, create_bors_process, create_github_client, create_github_client_from_access_token, + load_repositories, BorsContextBuilder, BorsGlobalEvent, CommandParser, GithubRepoName, + PgDbClient, ServerState, TeamApiClient, WebhookSecret, }; use clap::Parser; use sqlx::postgres::PgConnectOptions; @@ -18,6 +19,8 @@ use tracing_subscriber::filter::EnvFilter; /// How often should the bot check DB state, e.g. for handling timeouts. const PERIODIC_REFRESH: Duration = Duration::from_secs(120); +const GITHUB_API_URL: &str = "https://api.github.com"; + #[derive(clap::Parser)] struct Opts { /// Github App ID. @@ -39,6 +42,10 @@ struct Opts { /// Prefix used for bot commands in PR comments. #[arg(long, env = "CMD_PREFIX", default_value = "@bors")] cmd_prefix: String, + + /// Prefix used for bot commands in PR comments. + #[arg(long, env = "CI_ACCESS_TOKEN")] + ci_access_token: Option, } /// Starts a server that receives GitHub webhooks and generates events into a queue @@ -81,18 +88,34 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { let db = runtime .block_on(initialize_db(&opts.db)) .context("Cannot initialize database")?; - let team_api = TeamApiClient::default(); - let (client, loaded_repos) = runtime.block_on(async { - let client = create_github_client( + let team_api_client = TeamApiClient::default(); + let client = runtime.block_on(async { + create_github_client( opts.app_id.into(), - "https://api.github.com".to_string(), + GITHUB_API_URL.to_string(), opts.private_key.into(), - )?; - let repos = load_repositories(&client, &team_api).await?; - Ok::<_, anyhow::Error>((client, repos)) + ) + })?; + let ci_client = match opts.ci_access_token { + Some(access_token) => { + let client = runtime.block_on(async { + tracing::warn!("creating client ci"); + create_github_client_from_access_token( + GITHUB_API_URL.to_string(), + access_token.into(), + ) + })?; + Some(client) + } + None => None, + }; + let loaded_repos = runtime.block_on(async { + let repos = load_repositories(&client, ci_client.clone(), &team_api_client).await?; + Ok::<_, anyhow::Error>(repos) })?; let mut repos = HashMap::default(); + let mut ci_repo_map: HashMap = HashMap::default(); for (name, repo) in loaded_repos { let repo = match repo { Ok(repo) => { @@ -105,11 +128,27 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { )); } }; + if repo.ci_client.repository() != repo.client.repository() { + ci_repo_map.insert( + repo.ci_client.repository().clone(), + repo.client.repository().clone(), + ); + } repos.insert(name, Arc::new(repo)); } - let ctx = BorsContext::new(CommandParser::new(opts.cmd_prefix), Arc::new(db), repos); - let (repository_tx, global_tx, bors_process) = create_bors_process(ctx, client, team_api); + let ctx = BorsContextBuilder::default() + .parser(CommandParser::new(opts.cmd_prefix)) + .db(Arc::new(db)) + .repositories(repos) + .gh_client(client) + .ci_client(ci_client) + .ci_repo_map(ci_repo_map) + .team_api_client(team_api_client) + .build() + .unwrap(); + + let (repository_tx, global_tx, bors_process) = create_bors_process(ctx); let refresh_tx = global_tx.clone(); let refresh_process = async move { diff --git a/src/bors/command/parser.rs b/src/bors/command/parser.rs index e922f46..ea37ff4 100644 --- a/src/bors/command/parser.rs +++ b/src/bors/command/parser.rs @@ -22,6 +22,7 @@ enum CommandPart<'a> { KeyValue { key: &'a str, value: &'a str }, } +#[derive(Clone)] pub struct CommandParser { prefix: String, } diff --git a/src/bors/comment.rs b/src/bors/comment.rs index 36c509a..7e1c62f 100644 --- a/src/bors/comment.rs +++ b/src/bors/comment.rs @@ -55,3 +55,23 @@ pub fn try_build_in_progress_comment() -> Comment { pub fn cant_find_last_parent_comment() -> Comment { Comment::new(":exclamation: There was no previous build. Please set an explicit parent or remove the `parent=last` argument to use the default parent.".to_string()) } + +pub fn no_try_build_in_progress_comment() -> Comment { + Comment::new(":exclamation: There is currently no try build in progress.".to_string()) +} + +pub fn unclean_try_build_cancelled_comment() -> Comment { + Comment::new( + "Try build was cancelled. It was not possible to cancel some workflows.".to_string(), + ) +} + +pub fn try_build_cancelled_comment(workflow_urls: impl Iterator) -> Comment { + let mut try_build_cancelled_comment = r#"Try build cancelled. +Cancelled workflows:"# + .to_string(); + for url in workflow_urls { + try_build_cancelled_comment += format!("\n- {}", url).as_str(); + } + Comment::new(try_build_cancelled_comment) +} diff --git a/src/bors/context.rs b/src/bors/context.rs index b3a1238..34573de 100644 --- a/src/bors/context.rs +++ b/src/bors/context.rs @@ -3,27 +3,25 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{bors::command::CommandParser, github::GithubRepoName, PgDbClient}; +use derive_builder::Builder; +use octocrab::Octocrab; + +use crate::{bors::command::CommandParser, github::GithubRepoName, PgDbClient, TeamApiClient}; use super::RepositoryState; +#[derive(Builder)] pub struct BorsContext { pub parser: CommandParser, pub db: Arc, + #[builder(field( + ty = "HashMap>", + build = "RwLock::new(self.repositories.clone())" + ))] pub repositories: RwLock>>, -} - -impl BorsContext { - pub fn new( - parser: CommandParser, - db: Arc, - repositories: HashMap>, - ) -> Self { - let repositories = RwLock::new(repositories); - Self { - parser, - db, - repositories, - } - } + pub gh_client: Octocrab, + #[builder(default)] + pub ci_client: Option, + pub ci_repo_map: HashMap, + pub team_api_client: TeamApiClient, } diff --git a/src/bors/handlers/mod.rs b/src/bors/handlers/mod.rs index 7b76da5..c74cf38 100644 --- a/src/bors/handlers/mod.rs +++ b/src/bors/handlers/mod.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use anyhow::Context; -use octocrab::Octocrab; use tracing::Instrument; use crate::bors::command::{BorsCommand, CommandParseError}; @@ -17,7 +16,7 @@ use crate::bors::handlers::workflow::{ handle_check_suite_completed, handle_workflow_completed, handle_workflow_started, }; use crate::bors::{BorsContext, Comment, RepositoryState}; -use crate::{load_repositories, PgDbClient, TeamApiClient}; +use crate::{load_repositories, PgDbClient}; #[cfg(test)] use crate::tests::util::TestSyncMarker; @@ -39,13 +38,12 @@ pub async fn handle_bors_repository_event( ctx: Arc, ) -> anyhow::Result<()> { let db = Arc::clone(&ctx.db); - let Some(repo) = ctx - .repositories - .read() - .unwrap() - .get(event.repository()) - .cloned() - else { + let repo_name = if let Some(repo_name) = ctx.ci_repo_map.get(event.repository()) { + repo_name + } else { + event.repository() + }; + let Some(repo) = ctx.repositories.read().unwrap().get(repo_name).cloned() else { return Err(anyhow::anyhow!( "Repository {} not found in the bot state", event.repository() @@ -142,16 +140,12 @@ pub static WAIT_FOR_REFRESH: TestSyncMarker = TestSyncMarker::new(); pub async fn handle_bors_global_event( event: BorsGlobalEvent, ctx: Arc, - gh_client: &Octocrab, - team_api_client: &TeamApiClient, ) -> anyhow::Result<()> { let db = Arc::clone(&ctx.db); match event { BorsGlobalEvent::InstallationsChanged => { let span = tracing::info_span!("Installations changed"); - reload_repos(ctx, gh_client, team_api_client) - .instrument(span) - .await?; + reload_repos(ctx).instrument(span).await?; } BorsGlobalEvent::Refresh => { let span = tracing::info_span!("Refresh"); @@ -161,7 +155,7 @@ pub async fn handle_bors_global_event( let repo = Arc::clone(&repo); async { let subspan = tracing::info_span!("Repo", repo = repo.repository().to_string()); - refresh_repository(repo, Arc::clone(&db), team_api_client) + refresh_repository(repo, Arc::clone(&db), &ctx.team_api_client) .instrument(subspan) .await } @@ -274,12 +268,9 @@ async fn handle_comment( Ok(()) } -async fn reload_repos( - ctx: Arc, - gh_client: &Octocrab, - team_api_client: &TeamApiClient, -) -> anyhow::Result<()> { - let reloaded_repos = load_repositories(gh_client, team_api_client).await?; +async fn reload_repos(ctx: Arc) -> anyhow::Result<()> { + let reloaded_repos = + load_repositories(&ctx.gh_client, ctx.ci_client.clone(), &ctx.team_api_client).await?; let mut repositories = ctx.repositories.write().unwrap(); for repo in repositories.values() { if !reloaded_repos.contains_key(repo.repository()) { diff --git a/src/bors/handlers/refresh.rs b/src/bors/handlers/refresh.rs index b6c146a..f7957da 100644 --- a/src/bors/handlers/refresh.rs +++ b/src/bors/handlers/refresh.rs @@ -23,6 +23,8 @@ pub async fn refresh_repository( ) { Ok(()) } else { + // FIXME: better error handling + // If a repo failed to be reload, there is no way to know tracing::error!("Failed to refresh repository"); anyhow::bail!("Failed to refresh repository") } diff --git a/src/bors/handlers/trybuild.rs b/src/bors/handlers/trybuild.rs index 9f80ece..05febd6 100644 --- a/src/bors/handlers/trybuild.rs +++ b/src/bors/handlers/trybuild.rs @@ -4,12 +4,16 @@ use anyhow::{anyhow, Context}; use crate::bors::command::Parent; use crate::bors::comment::cant_find_last_parent_comment; +use crate::bors::comment::no_try_build_in_progress_comment; +use crate::bors::comment::try_build_cancelled_comment; use crate::bors::comment::try_build_in_progress_comment; +use crate::bors::comment::unclean_try_build_cancelled_comment; use crate::bors::handlers::labels::handle_label_trigger; use crate::bors::Comment; use crate::bors::RepositoryState; use crate::database::RunId; -use crate::database::{BuildModel, BuildStatus, PullRequestModel, WorkflowStatus, WorkflowType}; +use crate::database::{BuildModel, BuildStatus, PullRequestModel}; +use crate::github::api::client::GithubRepositoryClient; use crate::github::GithubRepoName; use crate::github::{ CommitSha, GithubUser, LabelTrigger, MergeError, PullRequest, PullRequestNumber, @@ -44,8 +48,10 @@ pub(super) async fn command_try_build( return Ok(()); } + // Create pr model based on CI repo, so we can retrieve the pr later when + // the CI repo emits events let pr_model = db - .get_or_create_pull_request(repo.client.repository(), pr.number) + .get_or_create_pull_request(repo.ci_client.repository(), pr.number) .await .context("Cannot find or create PR")?; @@ -62,66 +68,94 @@ pub(super) async fn command_try_build( } }; + match attempt_merge( + &repo.ci_client, + &pr.head.sha, + &base_sha, + &auto_merge_commit_message(pr, repo.client.repository(), "", jobs), + ) + .await? + { + MergeResult::Success(merge_sha) => { + // If the merge was succesful, run CI with merged commit + run_try_build(&repo.ci_client, &db, pr_model, merge_sha.clone(), base_sha).await?; + + handle_label_trigger(repo, pr.number, LabelTrigger::TryBuildStarted).await?; + + repo.client + .post_comment(pr.number, trying_build_comment(&pr.head.sha, &merge_sha)) + .await + } + MergeResult::Conflict => { + repo.client + .post_comment(pr.number, merge_conflict_comment(&pr.head.name)) + .await + } + } +} + +async fn attempt_merge( + ci_client: &GithubRepositoryClient, + head_sha: &CommitSha, + base_sha: &CommitSha, + merge_message: &str, +) -> anyhow::Result { tracing::debug!("Attempting to merge with base SHA {base_sha}"); // First set the try branch to our base commit (either the selected parent or the main branch). - repo.client - .set_branch_to_sha(TRY_MERGE_BRANCH_NAME, &base_sha) + ci_client + .set_branch_to_sha(TRY_MERGE_BRANCH_NAME, base_sha) .await .map_err(|error| anyhow!("Cannot set try merge branch to {}: {error:?}", base_sha.0))?; // Then merge the PR commit into the try branch - match repo - .client - .merge_branches( - TRY_MERGE_BRANCH_NAME, - &pr.head.sha, - &auto_merge_commit_message(pr, repo.client.repository(), "", jobs), - ) + match ci_client + .merge_branches(TRY_MERGE_BRANCH_NAME, head_sha, merge_message) .await { Ok(merge_sha) => { tracing::debug!("Merge successful, SHA: {merge_sha}"); - // If the merge was succesful, then set the actual try branch that will run CI to the - // merged commit. - repo.client - .set_branch_to_sha(TRY_BRANCH_NAME, &merge_sha) - .await - .map_err(|error| anyhow!("Cannot set try branch to main branch: {error:?}"))?; - db.attach_try_build( - pr_model, - TRY_BRANCH_NAME.to_string(), - merge_sha.clone(), - base_sha.clone(), - ) - .await?; - tracing::info!("Try build started"); - - handle_label_trigger(repo, pr.number, LabelTrigger::TryBuildStarted).await?; - - let comment = Comment::new(format!( - ":hourglass: Trying commit {} with merge {}…", - pr.head.sha.clone(), - merge_sha - )); - repo.client.post_comment(pr.number, comment).await?; - Ok(()) + Ok(MergeResult::Success(merge_sha)) } Err(MergeError::Conflict) => { tracing::warn!("Merge conflict"); - repo.client - .post_comment( - pr.number, - Comment::new(merge_conflict_message(&pr.head.name)), - ) - .await?; - Ok(()) + + Ok(MergeResult::Conflict) } Err(error) => Err(error.into()), } } +async fn run_try_build( + ci_client: &GithubRepositoryClient, + db: &PgDbClient, + pr_model: PullRequestModel, + commit_sha: CommitSha, + parent_sha: CommitSha, +) -> anyhow::Result<()> { + ci_client + .set_branch_to_sha(TRY_BRANCH_NAME, &commit_sha) + .await + .map_err(|error| anyhow!("Cannot set try branch to main branch: {error:?}"))?; + + db.attach_try_build( + pr_model, + TRY_BRANCH_NAME.to_string(), + commit_sha, + parent_sha, + ) + .await?; + + tracing::info!("Try build started"); + Ok(()) +} + +enum MergeResult { + Success(CommitSha), + Conflict, +} + fn get_base_sha( pr_model: &PullRequestModel, parent: Option, @@ -162,18 +196,13 @@ pub(super) async fn command_try_cancel( let pr_number: PullRequestNumber = pr.number; let pr = db - .get_or_create_pull_request(repo.client.repository(), pr_number) + .get_or_create_pull_request(repo.ci_client.repository(), pr_number) .await?; let Some(build) = get_pending_build(pr) else { tracing::warn!("No build found"); repo.client - .post_comment( - pr_number, - Comment::new( - ":exclamation: There is currently no try build in progress.".to_string(), - ), - ) + .post_comment(pr_number, no_try_build_in_progress_comment()) .await?; return Ok(()); }; @@ -187,13 +216,7 @@ pub(super) async fn command_try_cancel( db.update_build_status(&build, BuildStatus::Cancelled) .await?; repo.client - .post_comment( - pr_number, - Comment::new( - "Try build was cancelled. It was not possible to cancel some workflows." - .to_string(), - ), - ) + .post_comment(pr_number, unclean_try_build_cancelled_comment()) .await? } Ok(workflow_ids) => { @@ -201,16 +224,13 @@ pub(super) async fn command_try_cancel( .await?; tracing::info!("Try build cancelled"); - let mut try_build_cancelled_comment = r#"Try build cancelled. -Cancelled workflows:"# - .to_string(); - for id in workflow_ids { - let url = repo.client.get_workflow_url(id); - try_build_cancelled_comment += format!("\n- {}", url).as_str(); - } - repo.client - .post_comment(pr_number, Comment::new(try_build_cancelled_comment)) + .post_comment( + pr_number, + try_build_cancelled_comment( + repo.ci_client.get_workflow_urls(workflow_ids.into_iter()), + ), + ) .await? } }; @@ -223,16 +243,10 @@ pub async fn cancel_build_workflows( db: &PgDbClient, build: &BuildModel, ) -> anyhow::Result> { - let pending_workflows = db - .get_workflows_for_build(build) - .await? - .into_iter() - .filter(|w| w.status == WorkflowStatus::Pending && w.workflow_type == WorkflowType::Github) - .map(|w| w.run_id) - .collect::>(); + let pending_workflows = db.get_pending_workflows_for_build(build).await?; tracing::info!("Cancelling workflows {:?}", pending_workflows); - repo.client.cancel_workflows(&pending_workflows).await?; + repo.ci_client.cancel_workflows(&pending_workflows).await?; Ok(pending_workflows) } @@ -267,8 +281,14 @@ fn auto_merge_commit_message( message } -fn merge_conflict_message(branch: &str) -> String { - format!( +fn trying_build_comment(head_sha: &CommitSha, merge_sha: &CommitSha) -> Comment { + Comment::new(format!( + ":hourglass: Trying commit {head_sha} with merge {merge_sha}…" + )) +} + +fn merge_conflict_comment(branch: &str) -> Comment { + let message = format!( r#":lock: Merge conflict This pull request and the master branch diverged in a way that cannot @@ -298,7 +318,8 @@ handled during merge and rebase. This is normal, and you should still perform st "# - ) + ); + Comment::new(message) } async fn check_try_permissions( diff --git a/src/bors/handlers/workflow.rs b/src/bors/handlers/workflow.rs index 00b5871..04ad5fc 100644 --- a/src/bors/handlers/workflow.rs +++ b/src/bors/handlers/workflow.rs @@ -135,7 +135,7 @@ async fn try_complete_build( // Ask GitHub what are all the check suites attached to the given commit. // This tells us for how many workflows we should wait. let checks = repo - .client + .ci_client .get_check_suites_for_commit(&payload.branch, &payload.commit_sha) .await?; diff --git a/src/bors/mod.rs b/src/bors/mod.rs index 719a3bd..fdd02cc 100644 --- a/src/bors/mod.rs +++ b/src/bors/mod.rs @@ -3,6 +3,7 @@ use arc_swap::ArcSwap; pub use command::CommandParser; pub use comment::Comment; pub use context::BorsContext; +pub use context::BorsContextBuilder; #[cfg(test)] pub use handlers::WAIT_FOR_REFRESH; pub use handlers::{handle_bors_global_event, handle_bors_repository_event}; @@ -32,11 +33,14 @@ pub struct CheckSuite { pub(crate) status: CheckSuiteStatus, } -/// An access point to a single repository. +/// An access point to state of a single request /// Can be used to query permissions for the repository, and also to perform various /// actions using the stored client. pub struct RepositoryState { + /// Client of the main repo pub client: GithubRepositoryClient, + /// Client of the ci dedicated repo + pub ci_client: GithubRepositoryClient, pub permissions: ArcSwap, pub config: ArcSwap, } diff --git a/src/config.rs b/src/config.rs index 63b95cb..f95c0a2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,7 +4,7 @@ use std::time::Duration; use serde::de::Error; use serde::{Deserialize, Deserializer}; -use crate::github::{LabelModification, LabelTrigger}; +use crate::github::{GithubRepoName, LabelModification, LabelTrigger}; pub const CONFIG_FILE_PATH: &str = "rust-bors.toml"; @@ -17,6 +17,7 @@ pub struct RepositoryConfig { deserialize_with = "deserialize_duration_from_secs" )] pub timeout: Duration, + pub ci_repo: Option, #[serde(default, deserialize_with = "deserialize_labels")] pub labels: HashMap>, } diff --git a/src/database/client.rs b/src/database/client.rs index 8849fc3..184e32c 100644 --- a/src/database/client.rs +++ b/src/database/client.rs @@ -140,4 +140,20 @@ impl PgDbClient { ) -> anyhow::Result> { get_workflows_for_build(&self.pool, build.id).await } + + pub async fn get_pending_workflows_for_build( + &self, + build: &BuildModel, + ) -> anyhow::Result> { + let workflows = self + .get_workflows_for_build(build) + .await? + .into_iter() + .filter(|w| { + w.status == WorkflowStatus::Pending && w.workflow_type == WorkflowType::Github + }) + .map(|w| w.run_id) + .collect::>(); + Ok(workflows) + } } diff --git a/src/github/api/client.rs b/src/github/api/client.rs index 65a7787..cd84522 100644 --- a/src/github/api/client.rs +++ b/src/github/api/client.rs @@ -279,6 +279,14 @@ impl GithubRepositoryClient { format!("{html_url}/actions/runs/{run_id}") } + /// Get workflow url for a list of workflows. + pub fn get_workflow_urls<'a>( + &'a self, + run_ids: impl Iterator + 'a, + ) -> impl Iterator + 'a { + run_ids.map(|workflow_id| self.get_workflow_url(workflow_id)) + } + fn format_pr(&self, pr: PullRequestNumber) -> String { format!("{}/{}", self.repository(), pr) } @@ -312,7 +320,9 @@ mod tests { .await; let client = mock.github_client(); let team_api_client = mock.team_api_client(); - let mut repos = load_repositories(&client, &team_api_client).await.unwrap(); + let mut repos = load_repositories(&client, None, &team_api_client) + .await + .unwrap(); assert_eq!(repos.len(), 2); let repo = repos diff --git a/src/github/api/mod.rs b/src/github/api/mod.rs index 7ef3103..e575163 100644 --- a/src/github/api/mod.rs +++ b/src/github/api/mod.rs @@ -36,15 +36,27 @@ pub fn create_github_client( .context("Could not create octocrab builder") } +pub fn create_github_client_from_access_token( + github_url: String, + access_token: SecretString, +) -> anyhow::Result { + Octocrab::builder() + .base_uri(github_url)? + .user_access_token(access_token) + .build() + .context("Could not create octocrab builder") +} + /// Loads repositories that are connected to the given GitHub App client. /// The anyhow::Result is intended, because we wanted to have /// a hard error when the repos fail to load when the bot starts, but only log /// a warning when we reload the state during the bot's execution. pub async fn load_repositories( - client: &Octocrab, + gh_client: &Octocrab, + ci_client: Option, team_api_client: &TeamApiClient, ) -> anyhow::Result>> { - let installations = client + let installations = gh_client .apps() .installations() .send() @@ -53,7 +65,7 @@ pub async fn load_repositories( // installation client can not be used to load current app // https://docs.github.com/en/rest/apps/apps?apiVersion=2022-11-28#get-the-authenticated-app - let app = client + let app = gh_client .current() .app() .await @@ -61,7 +73,7 @@ pub async fn load_repositories( let mut repositories = HashMap::default(); for installation in installations { - let installation_client = client + let installation_client = gh_client .installation(installation.id) .context("failed to install client")?; @@ -75,32 +87,31 @@ pub async fn load_repositories( } }; for repo in repos { - let name = match parse_repo_name(&repo) { + let repo_name = match parse_repo_name(&repo) { Ok(name) => name, Err(error) => { tracing::error!("Found repository without a name: {error:?}"); continue; } }; - - if repositories.contains_key(&name) { + if repositories.contains_key(&repo_name) { return Err(anyhow::anyhow!( - "Repository {name} found in multiple installations!", + "Repository {repo_name} found in multiple installations!", )); } - let repo_state = create_repo_state( app.clone(), installation_client.clone(), + ci_client.clone(), team_api_client, - repo.clone(), - name.clone(), + repo_name.clone(), + &repo, ) .await .map_err(|error| { anyhow::anyhow!("Cannot load repository {:?}: {error:?}", repo.full_name) }); - repositories.insert(name, repo_state); + repositories.insert(repo_name, repo_state); } } Ok(repositories) @@ -142,24 +153,40 @@ fn parse_repo_name(repo: &Repository) -> anyhow::Result { async fn create_repo_state( app: App, - repo_client: Octocrab, + gh_app_installation_client: Octocrab, + ci_client: Option, team_api_client: &TeamApiClient, - repo: Repository, name: GithubRepoName, + repo: &Repository, ) -> anyhow::Result { tracing::info!("Found repository {name}"); - let client = GithubRepositoryClient::new(app, repo_client, name.clone(), repo); + let ci_client = ci_client.unwrap_or(gh_app_installation_client.clone()); + + let main_repo_client = GithubRepositoryClient::new( + app.clone(), + gh_app_installation_client.clone(), + name.clone(), + repo.clone(), + ); + let config = load_config(&main_repo_client).await?; + let ci_repo = get_ci_repo(config.ci_repo.clone(), repo.clone(), &ci_client).await?; + + let ci_repo_client = GithubRepositoryClient::new( + app, + ci_client, + config.ci_repo.clone().unwrap_or(name.clone()), + ci_repo, + ); let permissions = team_api_client .load_permissions(&name) .await .with_context(|| format!("Could not load permissions for repository {name}"))?; - let config = load_config(&client).await?; - Ok(RepositoryState { - client, + client: main_repo_client, + ci_client: ci_repo_client, config: ArcSwap::new(Arc::new(config)), permissions: ArcSwap::new(Arc::new(permissions)), }) @@ -177,3 +204,18 @@ async fn load_config(client: &GithubRepositoryClient) -> anyhow::Result, + default_client: Repository, + ci_client: &Octocrab, +) -> anyhow::Result { + let Some(ci_repo_name) = ci_repo_name else { + return Ok(default_client); + }; + let ci_repo = ci_client + .repos(ci_repo_name.owner(), ci_repo_name.name()) + .get() + .await?; + Ok(ci_repo) +} diff --git a/src/github/mod.rs b/src/github/mod.rs index 23e5280..07708e0 100644 --- a/src/github/mod.rs +++ b/src/github/mod.rs @@ -3,6 +3,7 @@ use std::fmt::{Debug, Display, Formatter}; use octocrab::models::UserId; +use serde::Deserialize; use url::Url; pub mod api; @@ -53,6 +54,16 @@ impl From for GithubRepoName { } } +impl<'de> Deserialize<'de> for GithubRepoName { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let full_name = String::deserialize(deserializer)?; + Ok(GithubRepoName::from(full_name)) + } +} + #[derive(Clone, Debug, PartialEq)] pub struct GithubUser { pub id: UserId, diff --git a/src/github/server.rs b/src/github/server.rs index ebb157c..13dab2d 100644 --- a/src/github/server.rs +++ b/src/github/server.rs @@ -2,7 +2,7 @@ use crate::bors::event::BorsEvent; use crate::bors::{handle_bors_global_event, handle_bors_repository_event, BorsContext}; use crate::github::webhook::GitHubWebhook; use crate::github::webhook::WebhookSecret; -use crate::{BorsGlobalEvent, BorsRepositoryEvent, TeamApiClient}; +use crate::{BorsGlobalEvent, BorsRepositoryEvent}; use anyhow::Error; use axum::extract::State; @@ -10,7 +10,6 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::Router; -use octocrab::Octocrab; use std::future::Future; use std::sync::Arc; use tokio::sync::mpsc; @@ -83,8 +82,6 @@ pub async fn github_webhook_handler( /// them. pub fn create_bors_process( ctx: BorsContext, - gh_client: Octocrab, - team_api: TeamApiClient, ) -> ( mpsc::Sender, mpsc::Sender, @@ -104,7 +101,7 @@ pub fn create_bors_process( { tokio::join!( consume_repository_events(ctx.clone(), repository_rx), - consume_global_events(ctx.clone(), global_rx, gh_client, team_api) + consume_global_events(ctx.clone(), global_rx) ); } // In real execution, the bot runs forever. If there is something that finishes @@ -116,7 +113,7 @@ pub fn create_bors_process( _ = consume_repository_events(ctx.clone(), repository_rx) => { tracing::error!("Repository event handling process has ended"); } - _ = consume_global_events(ctx.clone(), global_rx, gh_client, team_api) => { + _ = consume_global_events(ctx.clone(), global_rx) => { tracing::error!("Global event handling process has ended"); } } @@ -146,15 +143,13 @@ async fn consume_repository_events( async fn consume_global_events( ctx: Arc, mut global_rx: mpsc::Receiver, - gh_client: Octocrab, - team_api: TeamApiClient, ) { while let Some(event) = global_rx.recv().await { let ctx = ctx.clone(); let span = tracing::info_span!("GlobalEvent"); tracing::debug!("Received global event: {event:#?}"); - if let Err(error) = handle_bors_global_event(event, ctx, &gh_client, &team_api) + if let Err(error) = handle_bors_global_event(event, ctx) .instrument(span.clone()) .await { diff --git a/src/lib.rs b/src/lib.rs index 5e85551..f0124bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,13 +8,17 @@ mod github; mod permissions; mod utils; -pub use bors::{event::BorsGlobalEvent, event::BorsRepositoryEvent, BorsContext, CommandParser}; +pub use bors::{ + event::BorsGlobalEvent, event::BorsRepositoryEvent, BorsContext, BorsContextBuilder, + CommandParser, +}; pub use database::PgDbClient; pub use github::{ api::create_github_client, + api::create_github_client_from_access_token, api::load_repositories, server::{create_app, create_bors_process, ServerState}, - WebhookSecret, + GithubRepoName, WebhookSecret, }; pub use permissions::TeamApiClient; diff --git a/src/permissions.rs b/src/permissions.rs index 4c7022f..657dd5b 100644 --- a/src/permissions.rs +++ b/src/permissions.rs @@ -37,6 +37,7 @@ pub(crate) struct UserPermissionsResponse { github_ids: HashSet, } +#[derive(Clone)] pub struct TeamApiClient { base_url: String, } diff --git a/src/tests/mocks/bors.rs b/src/tests/mocks/bors.rs index a0f8ad5..6ff539d 100644 --- a/src/tests/mocks/bors.rs +++ b/src/tests/mocks/bors.rs @@ -24,8 +24,8 @@ use crate::tests::mocks::{ }; use crate::tests::webhook::{create_webhook_request, TEST_WEBHOOK_SECRET}; use crate::{ - create_app, create_bors_process, BorsContext, BorsGlobalEvent, CommandParser, PgDbClient, - ServerState, WebhookSecret, + create_app, create_bors_process, BorsContextBuilder, BorsGlobalEvent, CommandParser, + PgDbClient, ServerState, WebhookSecret, }; use super::pull_request::{GitHubPullRequestEventPayload, PullRequestChangeEvent}; @@ -100,7 +100,7 @@ impl BorsTester { let mock = ExternalHttpMock::start(&world).await; let db = Arc::new(PgDbClient::new(pool)); - let loaded_repos = load_repositories(&mock.github_client(), &mock.team_api_client()) + let loaded_repos = load_repositories(&mock.github_client(), None, &mock.team_api_client()) .await .unwrap(); let mut repos = HashMap::default(); @@ -109,10 +109,16 @@ impl BorsTester { repos.insert(name, Arc::new(repo)); } - let ctx = BorsContext::new(CommandParser::new("@bors".to_string()), db.clone(), repos); + let ctx = BorsContextBuilder::default() + .parser(CommandParser::new("@bors".to_string())) + .db(db.clone()) + .repositories(repos) + .gh_client(mock.github_client()) + .team_api_client(mock.team_api_client()) + .build() + .unwrap(); - let (repository_tx, global_tx, bors_process) = - create_bors_process(ctx, mock.github_client(), mock.team_api_client()); + let (repository_tx, global_tx, bors_process) = create_bors_process(ctx); let state = ServerState::new( repository_tx,