From c70d82a31b19d5594e08800d05c4e7dde3ead435 Mon Sep 17 00:00:00 2001 From: masato Date: Wed, 1 Jan 2025 03:34:09 +0900 Subject: [PATCH] Make a query overridable. --- refinery_core/src/runner.rs | 2 +- refinery_core/src/traits/async.rs | 23 +++++++++++------------ refinery_core/src/traits/sync.rs | 28 ++++++++++++++++------------ 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/refinery_core/src/runner.rs b/refinery_core/src/runner.rs index 46d4d714..af8e2ab5 100644 --- a/refinery_core/src/runner.rs +++ b/refinery_core/src/runner.rs @@ -204,7 +204,7 @@ pub struct Report { impl Report { /// Instantiate a new Report - pub(crate) fn new(applied_migrations: Vec) -> Report { + pub fn new(applied_migrations: Vec) -> Report { Report { applied_migrations } } diff --git a/refinery_core/src/traits/async.rs b/refinery_core/src/traits/async.rs index 8af3430f..32908b83 100644 --- a/refinery_core/src/traits/async.rs +++ b/refinery_core/src/traits/async.rs @@ -132,15 +132,20 @@ where ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) } + fn get_last_applied_migration_query(migration_table_name: &str) -> String { + GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + + fn get_applied_migrations_query(migration_table_name: &str) -> String { + GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + async fn get_last_applied_migration( &mut self, migration_table_name: &str, ) -> Result, Error> { let mut migrations = self - .query( - &GET_LAST_APPLIED_MIGRATION_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name), - ) + .query(Self::get_last_applied_migration_query(migration_table_name).as_str()) .await .migration_err("error getting last applied migration", None)?; @@ -152,10 +157,7 @@ where migration_table_name: &str, ) -> Result, Error> { let migrations = self - .query( - &GET_APPLIED_MIGRATIONS_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name), - ) + .query(Self::get_applied_migrations_query(migration_table_name).as_str()) .await .migration_err("error getting applied migrations", None)?; @@ -178,10 +180,7 @@ where .migration_err("error asserting migrations table", None)?; let applied_migrations = self - .query( - &GET_APPLIED_MIGRATIONS_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name), - ) + .get_applied_migrations(migration_table_name) .await .migration_err("error getting current schema version", None)?; diff --git a/refinery_core/src/traits/sync.rs b/refinery_core/src/traits/sync.rs index 7e552b2b..bf761035 100644 --- a/refinery_core/src/traits/sync.rs +++ b/refinery_core/src/traits/sync.rs @@ -92,14 +92,24 @@ pub trait Migrate: Query> where Self: Sized, { + // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table + fn assert_migrations_table_query(migration_table_name: &str) -> String { + ASSERT_MIGRATIONS_TABLE_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + + fn get_last_applied_migration_query(migration_table_name: &str) -> String { + GET_LAST_APPLIED_MIGRATION_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + + fn get_applied_migrations_query(migration_table_name: &str) -> String { + GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name) + } + fn assert_migrations_table(&mut self, migration_table_name: &str) -> Result { // Needed cause some database vendors like Mssql have a non sql standard way of checking the migrations table, // thou on this case it's just to be consistent with the async trait `AsyncMigrate` self.execute( - [ASSERT_MIGRATIONS_TABLE_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name) - .as_str()] - .into_iter(), + [Self::assert_migrations_table_query(migration_table_name).as_str()].into_iter(), ) .migration_err("error asserting migrations table", None) } @@ -109,10 +119,7 @@ where migration_table_name: &str, ) -> Result, Error> { let mut migrations = self - .query( - &GET_LAST_APPLIED_MIGRATION_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name), - ) + .query(Self::get_last_applied_migration_query(migration_table_name).as_str()) .migration_err("error getting last applied migration", None)?; Ok(migrations.pop()) @@ -123,10 +130,7 @@ where migration_table_name: &str, ) -> Result, Error> { let migrations = self - .query( - &GET_APPLIED_MIGRATIONS_QUERY - .replace("%MIGRATION_TABLE_NAME%", migration_table_name), - ) + .query(Self::get_applied_migrations_query(migration_table_name).as_str()) .migration_err("error getting applied migrations", None)?; Ok(migrations)