Skip to content

Commit 7750e72

Browse files
authored
Drop create schemas after test (#129)
1 parent aa83826 commit 7750e72

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

src/main.go

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ type tester struct {
102102
mdb *sql.DB
103103
name string
104104

105+
originalSchemas map[string]struct{}
106+
105107
curr *Conn
106108

107109
buf bytes.Buffer
@@ -274,6 +276,20 @@ func (t *tester) preProcess() {
274276
log.Fatalf("Open db err %v", err)
275277
}
276278

279+
if !reserveSchema {
280+
// store original schemas
281+
t.originalSchemas = make(map[string]struct{})
282+
rows, err := mdb.Query("show databases")
283+
if err != nil {
284+
log.Errorf("failed to get databases: %s", err.Error())
285+
return
286+
}
287+
for rows.Next() {
288+
rows.Scan(&dbName)
289+
t.originalSchemas[dbName] = struct{}{}
290+
}
291+
}
292+
277293
dbName = strings.ReplaceAll(t.name, "/", "__")
278294
log.Debugf("Create new db `%s`", dbName)
279295
if _, err = mdb.Exec(fmt.Sprintf("create database `%s`", dbName)); err != nil {
@@ -290,16 +306,30 @@ func (t *tester) preProcess() {
290306
}
291307

292308
func (t *tester) postProcess() {
309+
defer func() {
310+
for _, v := range t.conn {
311+
v.conn.Close()
312+
}
313+
t.mdb.Close()
314+
}()
293315
if !reserveSchema {
294-
_, err := t.mdb.Exec(fmt.Sprintf("drop database `%s`", strings.ReplaceAll(t.name, "/", "__")))
316+
rows, err := t.mdb.Query("show databases")
295317
if err != nil {
296-
log.Errorf("failed to drop database: %s", err.Error())
318+
log.Errorf("failed to get databases: %s", err.Error())
319+
return
320+
}
321+
var dbName string
322+
for rows.Next() {
323+
rows.Scan(&dbName)
324+
if _, exists := t.originalSchemas[dbName]; !exists {
325+
_, err := t.mdb.Exec(fmt.Sprintf("drop database `%s`", dbName))
326+
if err != nil {
327+
log.Errorf("failed to drop database: %s", err.Error())
328+
return
329+
}
330+
}
297331
}
298332
}
299-
for _, v := range t.conn {
300-
v.conn.Close()
301-
}
302-
t.mdb.Close()
303333
}
304334

305335
func (t *tester) addFailure(testSuite *XUnitTestSuite, err *error, cnt int) {

0 commit comments

Comments
 (0)