Skip to content

Commit

Permalink
Merge pull request #293 from cashapp/mtocker-asserty
Browse files Browse the repository at this point in the history
Add asserty feature
  • Loading branch information
morgo authored Jun 2, 2024
2 parents 2bce19b + d3923ca commit 2658705
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pkg/table/asserty/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# What is this?

Asserty is a package that provides a set of functions to help ensure that a database schema conforms to your expectations.

Sure, you could probably use it in tests. But our goal for this is to use it as part of service readiness checks. We want to make sure that the database schema is what we expect it to be before we start to serve traffic.
60 changes: 60 additions & 0 deletions pkg/table/asserty/asserty.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Package asserty offers functionality to assert for certain DB properties.
package asserty

import (
"context"
"database/sql"
"errors"
"slices"

"github.com/cashapp/spirit/pkg/table"
_ "github.com/go-sql-driver/mysql"
)

type Table struct {
ti *table.TableInfo
}

func LoadTable(db *sql.DB, schema, tableName string) (*Table, error) {
ti := table.NewTableInfo(db, schema, tableName)
if err := ti.SetInfo(context.TODO()); err != nil {
return nil, err
}
return &Table{ti: ti}, nil
}

func (t *Table) ContainsColumns(columnNames ...string) error {
for _, col := range columnNames {
if !slices.Contains(t.ti.Columns, col) {
return errors.New("missing column " + col + " on table " + t.ti.QuotedName)
}
}
return nil
}

func (t *Table) NotContainsColumns(columnNames ...string) error {
for _, col := range columnNames {
if slices.Contains(t.ti.Columns, col) {
return errors.New("unexpected column " + col + " on table " + t.ti.QuotedName)
}
}
return nil
}

func (t *Table) ContainsIndexes(indexNames ...string) error {
for _, idx := range indexNames {
if !slices.Contains(t.ti.Indexes, idx) {
return errors.New("missing index " + idx + " on table " + t.ti.QuotedName)
}
}
return nil
}

func (t *Table) NotContainsIndexes(indexNames ...string) error {
for _, idx := range indexNames {
if slices.Contains(t.ti.Indexes, idx) {
return errors.New("unexpected index " + idx + " on table " + t.ti.QuotedName)
}
}
return nil
}
39 changes: 39 additions & 0 deletions pkg/table/asserty/asserty_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package asserty

import (
"database/sql"
"testing"

"github.com/cashapp/spirit/pkg/testutils"
"github.com/stretchr/testify/require"
)

func TestBasicUsage(t *testing.T) {
db, err := sql.Open("mysql", testutils.DSN())
require.NoError(t, err)
defer db.Close()

testutils.RunSQL(t, `DROP TABLE IF EXISTS asserty_test`)
table := `CREATE TABLE asserty_test (
id int(11) unsigned NOT NULL AUTO_INCREMENT,
name varchar(115),
age int(11) NOT NULL,
PRIMARY KEY (id, age),
INDEX idx_name (name),
INDEX idx_age (age)
)`
testutils.RunSQL(t, table)

tbl, err := LoadTable(db, "test", "asserty_test")
require.NoError(t, err)
require.NoError(t, tbl.ContainsColumns("name", "age"))
require.Error(t, tbl.ContainsColumns("name", "somethingelse"))
require.NoError(t, tbl.NotContainsColumns("col1", "col2"))
require.Error(t, tbl.NotContainsColumns("col1", "name", "col2"))

require.NoError(t, tbl.ContainsIndexes("idx_name", "idx_age"))
require.Error(t, tbl.ContainsIndexes("idx_name", "idx_age", "idx_somethingelse"))
require.NoError(t, tbl.NotContainsIndexes("idx_col1", "idx_col2"))
require.Error(t, tbl.NotContainsIndexes("idx_col1", "idx_col2", "idx_name"))
require.Error(t, tbl.NotContainsIndexes("idx_name"))
}
27 changes: 27 additions & 0 deletions pkg/table/tableinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type TableInfo struct {
TableName string
QuotedName string
Columns []string // all the column names
Indexes []string // all the index names
columnsMySQLTps map[string]string // map from column name to MySQL type
KeyColumns []string // the column names of the primaryKey
keyColumnsMySQLTp []string // the MySQL types of the primaryKey
Expand Down Expand Up @@ -86,6 +87,9 @@ func (t *TableInfo) SetInfo(ctx context.Context) error {
if err := t.setPrimaryKey(ctx); err != nil {
return err
}
if err := t.setIndexes(ctx); err != nil {
return err
}
return t.setMinMax(ctx)
}

Expand All @@ -106,6 +110,29 @@ func (t *TableInfo) setRowEstimate(ctx context.Context) error {
return nil
}

func (t *TableInfo) setIndexes(ctx context.Context) error {
rows, err := t.db.QueryContext(ctx, "SELECT DISTINCT INDEX_NAME FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema=? AND table_name=? AND index_name != 'PRIMARY'",
t.SchemaName,
t.TableName,
)
if err != nil {
return err
}
defer rows.Close()
t.Indexes = []string{}
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return err
}
t.Indexes = append(t.Indexes, name)
}
if rows.Err() != nil {
return rows.Err()
}
return nil
}

func (t *TableInfo) setColumns(ctx context.Context) error {
rows, err := t.db.QueryContext(ctx, "SELECT column_name, column_type FROM information_schema.columns WHERE table_schema=? AND table_name=? ORDER BY ORDINAL_POSITION",
t.SchemaName,
Expand Down

0 comments on commit 2658705

Please sign in to comment.