Skip to content

Commit 1c6acce

Browse files
authored
feat: add transaction support for nested and singular operations (#269)
1 parent 0714bb8 commit 1c6acce

File tree

4 files changed

+183
-54
lines changed

4 files changed

+183
-54
lines changed

adapter.go

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package gormadapter
1717
import (
1818
"context"
1919
"database/sql"
20-
"errors"
2120
"fmt"
2221
"runtime"
2322
"strings"
@@ -27,6 +26,7 @@ import (
2726
"github.com/casbin/casbin/v2/model"
2827
"github.com/casbin/casbin/v2/persist"
2928
"github.com/glebarez/sqlite"
29+
"github.com/pkg/errors"
3030
"gorm.io/driver/mysql"
3131
"gorm.io/driver/postgres"
3232
"gorm.io/driver/sqlserver"
@@ -703,31 +703,61 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro
703703
}
704704
})
705705
}
706+
707+
// check adapter type
708+
adapter, ok := e.GetAdapter().(*Adapter)
709+
if !ok {
710+
return errors.New("expected adapter of type Adapter, but got incompatible type")
711+
}
712+
713+
// check if we're already in a transaction by checking if the current adapter is a transaction adapter
714+
if _, isTxAdapter := adapter.db.Statement.ConnPool.(*sql.Tx); isTxAdapter {
715+
// we're already in a transaction, just execute the function directly
716+
return fc(e)
717+
}
718+
706719
// lock the transactionMu to ensure the transaction is thread-safe
707720
a.transactionMu.Lock()
708721
defer a.transactionMu.Unlock()
709-
var err error
710-
// reload policy from database to sync with the transaction
711-
defer func() {
712-
e.SetAdapter(a.Copy())
713-
err = e.LoadPolicy()
722+
723+
// save original adapter
724+
originalAdapter := adapter.Copy()
725+
726+
// use GORM transaction functionality
727+
err := adapter.db.Transaction(func(tx *gorm.DB) error {
728+
// create transaction adapter
729+
txAdapter, err := NewAdapterByDB(tx)
730+
if err != nil {
731+
return errors.Wrap(err, "failed to initialize gorm adapter")
732+
}
733+
734+
// temporarily set transaction adapter
735+
e.SetAdapter(txAdapter)
736+
737+
// execute transaction function
738+
err = fc(e)
714739
if err != nil {
715-
panic(err)
740+
return errors.Wrap(err, "failed transactional policy operations")
716741
}
717-
}()
718-
copyDB := *a.db
719-
tx := copyDB.Begin(opts...)
720-
b := a.Copy()
721-
// copy enforcer to set the new adapter with transaction tx
722-
copyEnforcer := e
723-
copyEnforcer.SetAdapter(b)
724-
err = fc(copyEnforcer)
742+
743+
return nil
744+
}, opts...)
745+
746+
// restore original adapter
747+
e.SetAdapter(originalAdapter)
748+
725749
if err != nil {
726-
tx.Rollback()
727-
return err
750+
// LoadPolicy is called only when the transaction encounters an error and fails.
751+
// While this operation is expensive, failures are rare due to validation at earlier layers.
752+
// When a transaction fails, the in-memory model may be out of sync, so LoadPolicy is needed
753+
// to restore consistency by reloading from the database.
754+
if loadErr := e.LoadPolicy(); loadErr != nil {
755+
return errors.Wrap(loadErr, "failed to load policy after transaction failure")
756+
}
757+
return errors.Wrap(err, "transaction execution failed")
728758
}
729-
err = tx.Commit().Error
730-
return err
759+
760+
return nil
731761
}
732762

733763
// RemovePolicies removes multiple policy rules from the storage.

adapter_test.go

Lines changed: 131 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -730,36 +730,125 @@ func TestAddPolicy(t *testing.T) {
730730
}
731731

732732
func TestTransaction(t *testing.T) {
733-
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
734-
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
735-
err := e.GetAdapter().(*Adapter).Transaction(e, func(e casbin.IEnforcer) error {
736-
_, err := e.AddPolicy("jack", "data1", "write")
737-
if err != nil {
738-
return err
739-
}
740-
_, err = e.AddPolicy("jack", "data2", "write")
741-
//err = errors.New("some error")
742-
if err != nil {
733+
// create in-memory database
734+
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
735+
assert.NoError(t, err)
736+
737+
// create adapter
738+
adapter, err := NewAdapterByDB(db)
739+
assert.NoError(t, err)
740+
741+
// create enforcer
742+
enforcer, err := casbin.NewEnforcer("examples/rbac_model.conf", adapter)
743+
assert.NoError(t, err)
744+
745+
// load policy
746+
err = enforcer.LoadPolicy()
747+
assert.NoError(t, err)
748+
749+
// test 1: basic transaction operation
750+
t.Run("Basic Transaction", func(t *testing.T) {
751+
err := adapter.Transaction(enforcer, func(e casbin.IEnforcer) error {
752+
_, err := e.AddPolicy("alice", "data1", "read")
753+
if err != nil {
754+
return err
755+
}
756+
_, err = e.AddPolicy("alice", "data2", "write")
743757
return err
744-
}
745-
return nil
758+
})
759+
assert.NoError(t, err)
760+
761+
// verify policies were added successfully
762+
ok, _ := enforcer.Enforce("alice", "data1", "read")
763+
assert.True(t, ok)
764+
ok, _ = enforcer.Enforce("alice", "data2", "write")
765+
assert.True(t, ok)
766+
})
767+
768+
// test 2: transaction rollback
769+
t.Run("Transaction Rollback", func(t *testing.T) {
770+
err := adapter.Transaction(enforcer, func(e casbin.IEnforcer) error {
771+
_, err := e.AddPolicy("bob", "data3", "read")
772+
if err != nil {
773+
return err
774+
}
775+
// intentionally return error to trigger rollback
776+
return assert.AnError
777+
})
778+
assert.Error(t, err)
779+
780+
// verify policy was rolled back
781+
ok, _ := enforcer.Enforce("bob", "data3", "read")
782+
assert.False(t, ok)
783+
})
784+
785+
// test 3: nested transaction
786+
t.Run("Nested Transaction", func(t *testing.T) {
787+
err := adapter.Transaction(enforcer, func(e casbin.IEnforcer) error {
788+
// outer transaction
789+
_, err := e.AddPolicy("charlie", "data4", "read")
790+
if err != nil {
791+
return err
792+
}
793+
794+
// nested transaction
795+
return adapter.Transaction(e, func(innerE casbin.IEnforcer) error {
796+
_, err := innerE.AddPolicy("charlie", "data5", "write")
797+
if err != nil {
798+
return err
799+
}
800+
_, err = innerE.AddPolicy("charlie", "data6", "delete")
801+
return err
802+
})
803+
})
804+
assert.NoError(t, err)
805+
806+
// verify all policies
807+
ok, _ := enforcer.Enforce("charlie", "data4", "read")
808+
assert.True(t, ok)
809+
ok, _ = enforcer.Enforce("charlie", "data5", "write")
810+
assert.True(t, ok)
811+
ok, _ = enforcer.Enforce("charlie", "data6", "delete")
812+
assert.True(t, ok)
813+
})
814+
815+
// test 4: adapter type check
816+
t.Run("Adapter Type Check", func(t *testing.T) {
817+
// create an incompatible adapter
818+
mockEnforcer, _ := casbin.NewEnforcer("examples/rbac_model.conf")
819+
820+
err := adapter.Transaction(mockEnforcer, func(e casbin.IEnforcer) error {
821+
return nil
822+
})
823+
assert.Error(t, err)
824+
assert.Contains(t, err.Error(), "expected adapter of type Adapter")
746825
})
747-
if err != nil {
748-
return
749-
}
750826
}
751827

752828
func TestTransactionRace(t *testing.T) {
753-
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
754-
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
829+
// create in-memory database for testing
830+
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
831+
require.NoError(t, err)
832+
833+
// create adapter
834+
a, err := NewAdapterByDB(db)
835+
require.NoError(t, err)
836+
837+
// create enforcer
838+
e, err := casbin.NewEnforcer("examples/rbac_model.conf", a)
839+
require.NoError(t, err)
840+
841+
// load policy
842+
err = e.LoadPolicy()
843+
require.NoError(t, err)
755844

756845
concurrency := 100
757846

758847
var g errgroup.Group
759848
for i := 0; i < concurrency; i++ {
760849
i := i
761850
g.Go(func() error {
762-
return e.GetAdapter().(*Adapter).Transaction(e, func(e casbin.IEnforcer) error {
851+
return a.Transaction(e, func(e casbin.IEnforcer) error {
763852
_, err := e.AddPolicy("jack", fmt.Sprintf("data%d", i), "write")
764853
if err != nil {
765854
return err
@@ -781,16 +870,23 @@ func TestTransactionRace(t *testing.T) {
781870
}
782871

783872
func TestTransactionWithSavePolicy(t *testing.T) {
784-
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
785-
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
786-
defer func() {
787-
e.ClearPolicy()
788-
err := e.SavePolicy()
789-
if err != nil {
790-
t.Fatalf("save policy err %v", err)
791-
}
792-
}()
793-
err := e.GetAdapter().(*Adapter).Transaction(e, func(e casbin.IEnforcer) error {
873+
// create in-memory database for testing
874+
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
875+
require.NoError(t, err)
876+
877+
// create adapter
878+
a, err := NewAdapterByDB(db)
879+
require.NoError(t, err)
880+
881+
// create enforcer
882+
e, err := casbin.NewEnforcer("examples/rbac_model.conf", a)
883+
require.NoError(t, err)
884+
885+
// load policy
886+
err = e.LoadPolicy()
887+
require.NoError(t, err)
888+
889+
err = a.Transaction(e, func(e casbin.IEnforcer) error {
794890
_, err := e.AddPolicy("jack", "data1", "write")
795891
if err != nil {
796892
return err
@@ -799,13 +895,13 @@ func TestTransactionWithSavePolicy(t *testing.T) {
799895
if err != nil {
800896
return err
801897
}
802-
err = e.SavePolicy()
803-
if err != nil {
804-
return err
805-
}
806898
return nil
807899
})
808-
if err != nil {
809-
return
810-
}
900+
require.NoError(t, err)
901+
902+
// verify policies were added successfully
903+
ok, _ := e.Enforce("jack", "data1", "write")
904+
require.True(t, ok)
905+
ok, _ = e.Enforce("jack", "data2", "write")
906+
require.True(t, ok)
811907
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ require (
3434
github.com/kr/text v0.1.0 // indirect
3535
github.com/mattn/go-isatty v0.0.17 // indirect
3636
github.com/microsoft/go-mssqldb v1.6.0 // indirect
37+
github.com/pkg/errors v0.9.1 // indirect
3738
github.com/pmezard/go-difflib v1.0.0 // indirect
3839
github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 // indirect
3940
github.com/rogpeppe/go-internal v1.12.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3P
8686
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
8787
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
8888
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
89+
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
90+
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
8991
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
9092
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
9193
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=

0 commit comments

Comments
 (0)