@@ -730,36 +730,125 @@ func TestAddPolicy(t *testing.T) {
730730}
731731
732732func TestTransaction (t * testing.T ) {
733- a := initAdapter (t , "mysql" , "root:@tcp(127.0.0.1:3306)/" , "casbin" , "casbin_rule" )
734- e , _ := casbin .NewEnforcer ("examples/rbac_model.conf" , a )
735- err := e .GetAdapter ().(* Adapter ).Transaction (e , func (e casbin.IEnforcer ) error {
736- _ , err := e .AddPolicy ("jack" , "data1" , "write" )
737- if err != nil {
738- return err
739- }
740- _ , err = e .AddPolicy ("jack" , "data2" , "write" )
741- //err = errors.New("some error")
742- if err != nil {
733+ // create in-memory database
734+ db , err := gorm .Open (sqlite .Open (":memory:" ), & gorm.Config {})
735+ assert .NoError (t , err )
736+
737+ // create adapter
738+ adapter , err := NewAdapterByDB (db )
739+ assert .NoError (t , err )
740+
741+ // create enforcer
742+ enforcer , err := casbin .NewEnforcer ("examples/rbac_model.conf" , adapter )
743+ assert .NoError (t , err )
744+
745+ // load policy
746+ err = enforcer .LoadPolicy ()
747+ assert .NoError (t , err )
748+
749+ // test 1: basic transaction operation
750+ t .Run ("Basic Transaction" , func (t * testing.T ) {
751+ err := adapter .Transaction (enforcer , func (e casbin.IEnforcer ) error {
752+ _ , err := e .AddPolicy ("alice" , "data1" , "read" )
753+ if err != nil {
754+ return err
755+ }
756+ _ , err = e .AddPolicy ("alice" , "data2" , "write" )
743757 return err
744- }
745- return nil
758+ })
759+ assert .NoError (t , err )
760+
761+ // verify policies were added successfully
762+ ok , _ := enforcer .Enforce ("alice" , "data1" , "read" )
763+ assert .True (t , ok )
764+ ok , _ = enforcer .Enforce ("alice" , "data2" , "write" )
765+ assert .True (t , ok )
766+ })
767+
768+ // test 2: transaction rollback
769+ t .Run ("Transaction Rollback" , func (t * testing.T ) {
770+ err := adapter .Transaction (enforcer , func (e casbin.IEnforcer ) error {
771+ _ , err := e .AddPolicy ("bob" , "data3" , "read" )
772+ if err != nil {
773+ return err
774+ }
775+ // intentionally return error to trigger rollback
776+ return assert .AnError
777+ })
778+ assert .Error (t , err )
779+
780+ // verify policy was rolled back
781+ ok , _ := enforcer .Enforce ("bob" , "data3" , "read" )
782+ assert .False (t , ok )
783+ })
784+
785+ // test 3: nested transaction
786+ t .Run ("Nested Transaction" , func (t * testing.T ) {
787+ err := adapter .Transaction (enforcer , func (e casbin.IEnforcer ) error {
788+ // outer transaction
789+ _ , err := e .AddPolicy ("charlie" , "data4" , "read" )
790+ if err != nil {
791+ return err
792+ }
793+
794+ // nested transaction
795+ return adapter .Transaction (e , func (innerE casbin.IEnforcer ) error {
796+ _ , err := innerE .AddPolicy ("charlie" , "data5" , "write" )
797+ if err != nil {
798+ return err
799+ }
800+ _ , err = innerE .AddPolicy ("charlie" , "data6" , "delete" )
801+ return err
802+ })
803+ })
804+ assert .NoError (t , err )
805+
806+ // verify all policies
807+ ok , _ := enforcer .Enforce ("charlie" , "data4" , "read" )
808+ assert .True (t , ok )
809+ ok , _ = enforcer .Enforce ("charlie" , "data5" , "write" )
810+ assert .True (t , ok )
811+ ok , _ = enforcer .Enforce ("charlie" , "data6" , "delete" )
812+ assert .True (t , ok )
813+ })
814+
815+ // test 4: adapter type check
816+ t .Run ("Adapter Type Check" , func (t * testing.T ) {
817+ // create an incompatible adapter
818+ mockEnforcer , _ := casbin .NewEnforcer ("examples/rbac_model.conf" )
819+
820+ err := adapter .Transaction (mockEnforcer , func (e casbin.IEnforcer ) error {
821+ return nil
822+ })
823+ assert .Error (t , err )
824+ assert .Contains (t , err .Error (), "expected adapter of type Adapter" )
746825 })
747- if err != nil {
748- return
749- }
750826}
751827
752828func TestTransactionRace (t * testing.T ) {
753- a := initAdapter (t , "mysql" , "root:@tcp(127.0.0.1:3306)/" , "casbin" , "casbin_rule" )
754- e , _ := casbin .NewEnforcer ("examples/rbac_model.conf" , a )
829+ // create in-memory database for testing
830+ db , err := gorm .Open (sqlite .Open (":memory:" ), & gorm.Config {})
831+ require .NoError (t , err )
832+
833+ // create adapter
834+ a , err := NewAdapterByDB (db )
835+ require .NoError (t , err )
836+
837+ // create enforcer
838+ e , err := casbin .NewEnforcer ("examples/rbac_model.conf" , a )
839+ require .NoError (t , err )
840+
841+ // load policy
842+ err = e .LoadPolicy ()
843+ require .NoError (t , err )
755844
756845 concurrency := 100
757846
758847 var g errgroup.Group
759848 for i := 0 ; i < concurrency ; i ++ {
760849 i := i
761850 g .Go (func () error {
762- return e . GetAdapter ().( * Adapter ) .Transaction (e , func (e casbin.IEnforcer ) error {
851+ return a .Transaction (e , func (e casbin.IEnforcer ) error {
763852 _ , err := e .AddPolicy ("jack" , fmt .Sprintf ("data%d" , i ), "write" )
764853 if err != nil {
765854 return err
@@ -781,16 +870,23 @@ func TestTransactionRace(t *testing.T) {
781870}
782871
783872func TestTransactionWithSavePolicy (t * testing.T ) {
784- a := initAdapter (t , "mysql" , "root:@tcp(127.0.0.1:3306)/" , "casbin" , "casbin_rule" )
785- e , _ := casbin .NewEnforcer ("examples/rbac_model.conf" , a )
786- defer func () {
787- e .ClearPolicy ()
788- err := e .SavePolicy ()
789- if err != nil {
790- t .Fatalf ("save policy err %v" , err )
791- }
792- }()
793- err := e .GetAdapter ().(* Adapter ).Transaction (e , func (e casbin.IEnforcer ) error {
873+ // create in-memory database for testing
874+ db , err := gorm .Open (sqlite .Open (":memory:" ), & gorm.Config {})
875+ require .NoError (t , err )
876+
877+ // create adapter
878+ a , err := NewAdapterByDB (db )
879+ require .NoError (t , err )
880+
881+ // create enforcer
882+ e , err := casbin .NewEnforcer ("examples/rbac_model.conf" , a )
883+ require .NoError (t , err )
884+
885+ // load policy
886+ err = e .LoadPolicy ()
887+ require .NoError (t , err )
888+
889+ err = a .Transaction (e , func (e casbin.IEnforcer ) error {
794890 _ , err := e .AddPolicy ("jack" , "data1" , "write" )
795891 if err != nil {
796892 return err
@@ -799,13 +895,13 @@ func TestTransactionWithSavePolicy(t *testing.T) {
799895 if err != nil {
800896 return err
801897 }
802- err = e .SavePolicy ()
803- if err != nil {
804- return err
805- }
806898 return nil
807899 })
808- if err != nil {
809- return
810- }
900+ require .NoError (t , err )
901+
902+ // verify policies were added successfully
903+ ok , _ := e .Enforce ("jack" , "data1" , "write" )
904+ require .True (t , ok )
905+ ok , _ = e .Enforce ("jack" , "data2" , "write" )
906+ require .True (t , ok )
811907}
0 commit comments