nn: Fix MNIST (#916)
zhenghaoz authored Jan 4, 2025
1 parent 968f3ff commit b5d6890
Showing 8 changed files with 206 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_test.yml
@@ -95,7 +95,7 @@ jobs:
uses: actions/checkout@v2

- name: Test
- run: go test -timeout 20m -v ./... -coverprofile=coverage.txt -covermode=atomic -coverpkg=./...
+ run: go test -timeout 30m -v ./... -coverprofile=coverage.txt -covermode=atomic -coverpkg=./...
env:
# MySQL
MYSQL_URI: mysql://root:password@tcp(localhost:${{ job.services.mysql.ports[3306] }})/
4 changes: 3 additions & 1 deletion common/nn/layers.go
@@ -14,6 +14,8 @@

package nn

import "github.com/chewxy/math32"

type Layer interface {
Parameters() []*Tensor
Forward(x *Tensor) *Tensor
@@ -28,7 +30,7 @@ type linearLayer struct {

func NewLinear(in, out int) Layer {
return &linearLayer{
- w: Rand(in, out).RequireGrad(),
+ w: Normal(0, 1.0/math32.Sqrt(float32(in)), in, out).RequireGrad(),
b: Zeros(out).RequireGrad(),
}
}
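For context, a minimal sketch of the idea behind the Rand → Normal change above: the weights of NewLinear now use a fan-in-scaled normal initialization with std = 1/sqrt(in). The previous Rand init presumably draws from a fixed-range uniform distribution, so its scale does not shrink as the layer gets wider and pre-activations grow with fan-in; scaling by 1/sqrt(in) keeps their variance roughly constant, which matters for the 784-1000-10 MNIST model added below. The standalone illustration here (helper name, float64, and math/rand) is an assumption for clarity, not code from this commit:

package main

import (
	"fmt"
	"math"
	"math/rand"
)

// fanInNormal returns fanIn*fanOut weights drawn from N(0, (1/sqrt(fanIn))^2),
// so the variance of a pre-activation w·x stays close to the input variance.
func fanInNormal(fanIn, fanOut int) []float64 {
	std := 1.0 / math.Sqrt(float64(fanIn))
	w := make([]float64, fanIn*fanOut)
	for i := range w {
		w[i] = rand.NormFloat64() * std
	}
	return w
}

func main() {
	w := fanInNormal(784, 1000)
	fmt.Printf("generated %d weights with target std %.4f\n", len(w), 1.0/math.Sqrt(784))
}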
127 changes: 121 additions & 6 deletions common/nn/nn_test.go
@@ -15,14 +15,21 @@
package nn

import (
"bufio"
"encoding/csv"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"testing"

"github.com/chewxy/math32"
"github.com/samber/lo"
"github.com/schollz/progressbar/v3"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/common/dataset"
"github.com/zhenghaoz/gorse/common/util"
- "os"
- "path/filepath"
- "testing"
)

func TestLinearRegression(t *testing.T) {
@@ -47,9 +54,9 @@ func TestLinearRegression(t *testing.T) {
}

assert.Equal(t, []int{1, 1}, w.shape)
- assert.InDelta(t, float64(2), w.data[0], 0.5)
+ assert.InDelta(t, float64(2), w.data[0], 0.6)
assert.Equal(t, []int{1}, b.shape)
- assert.InDelta(t, float64(5), b.data[0], 0.5)
+ assert.InDelta(t, float64(5), b.data[0], 0.6)
}

func TestNeuralNetwork(t *testing.T) {
@@ -76,7 +83,7 @@ func TestNeuralNetwork(t *testing.T) {
optimizer.Step()
l = loss.data[0]
}
- assert.InDelta(t, float64(0), l, 0.1)
+ assert.InDelta(t, float64(0), l, 0.2)
}

func iris() (*Tensor, *Tensor, error) {
@@ -139,3 +146,111 @@ func TestIris(t *testing.T) {
}
assert.InDelta(t, float32(0), l, 0.1)
}

func mnist() (lo.Tuple2[*Tensor, *Tensor], lo.Tuple2[*Tensor, *Tensor], error) {
var train, test lo.Tuple2[*Tensor, *Tensor]
// Download and unzip dataset
path, err := dataset.DownloadAndUnzip("mnist")
if err != nil {
return train, test, err
}
// Open dataset
train.A, train.B, err = openMNISTFile(filepath.Join(path, "train.libfm"))
if err != nil {
return train, test, err
}
test.A, test.B, err = openMNISTFile(filepath.Join(path, "test.libfm"))
if err != nil {
return train, test, err
}
return train, test, nil
}

func openMNISTFile(path string) (*Tensor, *Tensor, error) {
// Open file
f, err := os.Open(path)
if err != nil {
return nil, nil, err
}
defer f.Close()
// Read data line by line
var (
images []float32
labels []float32
)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
splits := strings.Split(line, " ")
// Parse label
label, err := util.ParseFloat[float32](splits[0])
if err != nil {
return nil, nil, err
}
labels = append(labels, label)
// Parse image
image := make([]float32, 784)
for _, split := range splits[1:] {
kv := strings.Split(split, ":")
index, err := strconv.Atoi(kv[0])
if err != nil {
return nil, nil, err
}
value, err := util.ParseFloat[float32](kv[1])
if err != nil {
return nil, nil, err
}
image[index] = value
}
images = append(images, image...)
}
return NewTensor(images, len(labels), 784), NewTensor(labels, len(labels)), nil
}
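For reference, openMNISTFile assumes each line of train.libfm / test.libfm follows the sparse libFM/libSVM convention: a class label, then index:value pairs for the non-zero pixels, with indices in [0, 784). An illustrative line (the numbers are made up to show the shape of the format, not real data):

5 152:0.3 153:0.9 180:0.7 181:1

Each line is expanded into a dense 784-float row, so the returned tensors have shapes (N, 784) for the images and (N) for the labels.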

func TestMNIST(t *testing.T) {
train, test, err := mnist()
assert.NoError(t, err)

model := NewSequential(
NewLinear(784, 1000),
NewReLU(),
NewLinear(1000, 10),
)
optimizer := NewAdam(model.Parameters(), 0.001)

var (
sumLoss float32
batchSize = 1000
)
for i := 0; i < 3; i++ {
sumLoss = 0
bar := progressbar.Default(int64(train.A.shape[0]), fmt.Sprintf("Epoch %v/%v", i+1, 3))
for j := 0; j < train.A.shape[0]; j += batchSize {
xBatch := train.A.Slice(j, j+batchSize)
yBatch := train.B.Slice(j, j+batchSize)

yPred := model.Forward(xBatch)
loss := SoftmaxCrossEntropy(yPred, yBatch)

optimizer.ZeroGrad()
loss.Backward()

optimizer.Step()
sumLoss += loss.data[0]
bar.Add(batchSize)
}
sumLoss /= float32(train.A.shape[0] / batchSize)
bar.Finish()
}
assert.Less(t, sumLoss, float32(0.4))

testPred := model.Forward(test.A)
var precision float32
for i, gt := range test.B.data {
if testPred.Slice(i, i+1).argmax()[1] == int(gt) {
precision += 1
}
}
precision /= float32(len(test.B.data))
assert.Greater(t, float64(precision), 0.92)
}
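To run just this test locally, something like the following should work (standard go test flags; the 30-minute value mirrors the CI timeout bump above rather than a documented requirement):

go test -run TestMNIST -timeout 30m ./common/nn/

Note that the first run calls dataset.DownloadAndUnzip("mnist"), so it needs network access and somewhere to cache the extracted files.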
4 changes: 2 additions & 2 deletions common/nn/op.go
@@ -710,8 +710,8 @@ func (r *relu) forward(inputs ...*Tensor) *Tensor {
}

func (r *relu) backward(dy *Tensor) []*Tensor {
- dx := dy.clone()
- dx.maximum(NewScalar(0))
+ x := r.inputs[0]
+ dx := x.clone().gt(NewScalar(0)).mul(dy)
return []*Tensor{dx}
}
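For context on this fix: ReLU(x) = max(x, 0), so its derivative is 1 where the forward input x was positive and 0 elsewhere, and the backward pass should return dy masked by that indicator. The old code clamped the upstream gradient itself (max(dy, 0)), which lets positive gradients pass through units that were inactive in the forward pass and discards legitimate negative gradients. A standalone sketch of the intended rule over plain slices (illustrative, not the committed tensor code):

// reluBackward is an illustrative version of the rule above,
// assuming x and dy are slices of equal length.
func reluBackward(x, dy []float32) []float32 {
	dx := make([]float32, len(dy))
	for i := range dx {
		if x[i] > 0 {
			dx[i] = dy[i] // gradient flows only through active units
		}
		// otherwise dx[i] stays 0
	}
	return dx
}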

3 changes: 0 additions & 3 deletions common/nn/op_test.go
@@ -183,9 +183,6 @@ func TestDiv(t *testing.T) {
assert.InDeltaSlice(t, []float32{0.5, 2.0 / 3.0, 0.75, 4.0 / 5.0, 5.0 / 6.0, 6.0 / 7.0}, z.data, 1e-6)

// Test gradient
- x = Rand(2, 3).RequireGrad()
- y = Rand(2, 3).RequireGrad()
- z = Div(x, y)
z.Backward()
dx := numericalDiff(func(x *Tensor) *Tensor { return Div(x, y) }, x)
allClose(t, x.grad, dx)
64 changes: 56 additions & 8 deletions common/nn/tensor.go
@@ -17,15 +17,16 @@ package nn
import (
"container/heap"
"fmt"
"math"
"math/rand"
"strings"

"github.com/chewxy/math32"
mapset "github.com/deckarep/golang-set/v2"
"github.com/google/uuid"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/base/floats"
"golang.org/x/exp/slices"
- "math"
- "math/rand"
- "strings"
)

type Tensor struct {
@@ -94,6 +95,21 @@ func Rand(shape ...int) *Tensor {
}
}

func Normal(mean, std float32, shape ...int) *Tensor {
n := 1
for _, s := range shape {
n *= s
}
data := make([]float32, n)
for i := range data {
data[i] = float32(rand.NormFloat64())*std + mean
}
return &Tensor{
data: data,
shape: shape,
}
}
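A minimal usage sketch of the new constructor (the values here are illustrative):

// a 784x1000 weight tensor with entries drawn from N(0, 0.01^2)
w := Normal(0, 0.01, 784, 1000).RequireGrad()

NewLinear above calls it with std = 1.0/math32.Sqrt(float32(in)).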

// Ones creates a tensor filled with ones.
func Ones(shape ...int) *Tensor {
n := 1
@@ -590,6 +606,27 @@ func (t *Tensor) maximum(other *Tensor) {
}
}

func (t *Tensor) gt(other *Tensor) *Tensor {
if other.IsScalar() {
for i := range t.data {
if t.data[i] > other.data[0] {
t.data[i] = 1
} else {
t.data[i] = 0
}
}
} else {
for i := range t.data {
if t.data[i] > other.data[i] {
t.data[i] = 1
} else {
t.data[i] = 0
}
}
}
return t
}
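Note that gt mutates the receiver in place: each element becomes 1 where it exceeds the corresponding element (or the scalar) of other and 0 otherwise, and the receiver is returned for chaining, which is why the new relu backward clones x before calling it. A tiny worked example (illustrative): cloning a tensor with data [-1, 0, 2] and applying gt(NewScalar(0)) yields the mask [0, 0, 1], so the following mul(dy) keeps only the gradient at the last position.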

func (t *Tensor) transpose() *Tensor {
if len(t.shape) < 2 {
panic("transpose requires at least 2-D tensor")
@@ -694,13 +731,24 @@ func (t *Tensor) sum(axis int, keepDim bool) *Tensor {
}
}

- func (t *Tensor) hasNaN() bool {
- for i := range t.data {
- if math32.IsNaN(t.data[i]) {
- return true
+ func (t *Tensor) argmax() []int {
+ if len(t.data) == 0 {
+ return nil
+ }
+ maxValue := t.data[0]
+ maxIndex := 0
+ for i := 1; i < len(t.data); i++ {
+ if t.data[i] > maxValue {
+ maxValue = t.data[i]
+ maxIndex = i
}
}
- return false
+ indices := make([]int, len(t.shape))
+ for i := len(t.shape) - 1; i >= 0; i-- {
+ indices[i] = maxIndex % t.shape[i]
+ maxIndex /= t.shape[i]
+ }
+ return indices
}
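argmax locates the flat offset of the largest element and decodes it into per-axis indices in row-major order. A small worked example (illustrative): for a tensor of shape [1, 10] whose maximum sits at flat offset 7, the loop sets indices[1] = 7 % 10 = 7 and indices[0] = 0 % 1 = 0, returning [0, 7]; this is why TestMNIST reads argmax()[1] on each (1, 10) slice of the predictions to obtain the predicted digit.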

func NormalInit(t *Tensor, mean, std float32) {
13 changes: 6 additions & 7 deletions go.mod
@@ -7,7 +7,7 @@ require (
github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de
github.com/benhoyt/goawk v1.20.0
github.com/bits-and-blooms/bitset v1.2.1
- github.com/chewxy/math32 v1.10.1
+ github.com/chewxy/math32 v1.11.1
github.com/coreos/go-oidc/v3 v3.11.0
github.com/deckarep/golang-set/v2 v2.3.1
github.com/emicklei/go-restful-openapi/v2 v2.9.0
@@ -40,7 +40,7 @@ require (
github.com/redis/go-redis/extra/redisotel/v9 v9.5.3
github.com/redis/go-redis/v9 v9.7.0
github.com/samber/lo v1.38.1
- github.com/schollz/progressbar/v3 v3.9.0
+ github.com/schollz/progressbar/v3 v3.17.1
github.com/sclevine/yj v0.0.0-20210612025309-737bdf40a5d1
github.com/spf13/cobra v1.5.0
github.com/spf13/pflag v1.0.5
@@ -126,8 +126,7 @@ require (
github.com/leodido/go-urn v1.4.0 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
- github.com/mattn/go-isatty v0.0.16 // indirect
- github.com/mattn/go-runewidth v0.0.13 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
@@ -145,7 +144,7 @@
github.com/prometheus/procfs v0.8.0 // indirect
github.com/redis/go-redis/extra/rediscmd/v9 v9.5.3 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
- github.com/rivo/uniseg v0.3.4 // indirect
+ github.com/rivo/uniseg v0.4.7 // indirect
github.com/shopspring/decimal v1.3.1 // indirect
github.com/spf13/afero v1.9.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
@@ -165,8 +164,8 @@
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sync v0.8.0 // indirect
- golang.org/x/sys v0.26.0 // indirect
- golang.org/x/term v0.25.0 // indirect
+ golang.org/x/sys v0.28.0 // indirect
+ golang.org/x/term v0.27.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df // indirect