From 8d6638b74dc7c33febd3f1b70a3fd71379d6802f Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 18 Jan 2025 12:27:02 +0800 Subject: [PATCH] implement item to item recommendation (#905) --- common/ann/ann.go | 4 +- common/ann/ann_test.go | 97 ++++++++++++- common/ann/bruteforce.go | 19 +-- common/ann/hnsw.go | 20 +-- common/dataset/dataset_test.go | 14 -- .../dataset.go => datautil/datautil.go} | 4 +- common/datautil/datautil_test.go | 28 ++++ common/nn/nn_test.go | 6 +- config/config.go | 15 ++ config/config_test.go | 18 +++ dataset/dataset.go | 81 +++++++++++ dataset/dataset_test.go | 49 +++++++ go.mod | 2 +- go.sum | 4 +- logics/item_to_item.go | 109 +++++++++++++++ logics/item_to_item_test.go | 104 ++++++++++++++ master/tasks.go | 129 ++++++++++++++++-- master/tasks_test.go | 12 +- server/rest.go | 18 +++ server/rest_test.go | 1 + storage/cache/database.go | 16 ++- 21 files changed, 670 insertions(+), 80 deletions(-) delete mode 100644 common/dataset/dataset_test.go rename common/{dataset/dataset.go => datautil/datautil.go} (98%) create mode 100644 common/datautil/datautil_test.go create mode 100644 dataset/dataset.go create mode 100644 dataset/dataset_test.go create mode 100644 logics/item_to_item.go create mode 100644 logics/item_to_item_test.go diff --git a/common/ann/ann.go b/common/ann/ann.go index ec04acdcc..86e8f5b1e 100644 --- a/common/ann/ann.go +++ b/common/ann/ann.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package search +package ann import ( "github.com/samber/lo" @@ -21,5 +21,5 @@ import ( type Index interface { Add(v []float32) (int, error) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) - SearchVector(q []float32, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) + SearchVector(q []float32, k int, prune0 bool) []lo.Tuple2[int, float32] } diff --git a/common/ann/ann_test.go b/common/ann/ann_test.go index 9e237692b..c4d22871e 100644 --- a/common/ann/ann_test.go +++ b/common/ann/ann_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package search +package ann import ( "bufio" @@ -20,7 +20,7 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base/floats" - "github.com/zhenghaoz/gorse/common/dataset" + "github.com/zhenghaoz/gorse/common/datautil" "github.com/zhenghaoz/gorse/common/util" "os" "path/filepath" @@ -57,7 +57,7 @@ type MNIST struct { func mnist() (*MNIST, error) { // Download and unzip dataset - path, err := dataset.DownloadAndUnzip("mnist") + path, err := datautil.DownloadAndUnzip("mnist") if err != nil { return nil, err } @@ -136,13 +136,96 @@ func TestMNIST(t *testing.T) { // Test search r := 0.0 for _, image := range dat.TestImages[:testSize] { - gt, err := bf.SearchVector(image, 100, false) - assert.NoError(t, err) + gt := bf.SearchVector(image, 100, false) assert.Len(t, gt, 100) - scores, err := hnsw.SearchVector(image, 100, false) - assert.NoError(t, err) + scores := hnsw.SearchVector(image, 100, false) assert.Len(t, scores, 100) r += recall(gt, scores) } + r /= float64(testSize) assert.Greater(t, r, 0.99) } + +func movieLens() ([][]int, error) { + // Download and unzip dataset + path, err := datautil.DownloadAndUnzip("ml-1m") + if err != nil { + return nil, err + } + // Open file + f, err := os.Open(filepath.Join(path, "train.txt")) + if err != nil { + return nil, err + } + defer f.Close() + // Read data line by line + movies := make([][]int, 0) + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + splits := strings.Split(line, "\t") + userId, err := strconv.Atoi(splits[0]) + if err != nil { + return nil, err + } + movieId, err := strconv.Atoi(splits[1]) + if err != nil { + return nil, err + } + for movieId >= len(movies) { + movies = append(movies, make([]int, 0)) + } + movies[movieId] = append(movies[movieId], userId) + } + return movies, nil +} + +func jaccard(a, b []int) float32 { + var i, j, intersection int + for i < len(a) && j < len(b) { + if a[i] == b[j] { + intersection++ + i++ + j++ + } else if a[i] < b[j] { + i++ + } else { + j++ + } + } + if len(a)+len(b)-intersection == 0 { + return 1 + } + return 1 - float32(intersection)/float32(len(a)+len(b)-intersection) +} + +func TestMovieLens(t *testing.T) { + movies, err := movieLens() + assert.NoError(t, err) + + // Create brute-force index + bf := NewBruteforce(jaccard) + for _, movie := range movies { + _, err := bf.Add(movie) + assert.NoError(t, err) + } + + // Create HNSW index + hnsw := NewHNSW(jaccard) + for _, movie := range movies { + _, err := hnsw.Add(movie) + assert.NoError(t, err) + } + + // Test search + r := 0.0 + for i := range movies[:testSize] { + gt, err := bf.SearchIndex(i, 100, false) + assert.NoError(t, err) + scores, err := hnsw.SearchIndex(i, 100, false) + assert.NoError(t, err) + r += recall(gt, scores) + } + r /= float64(testSize) + assert.Greater(t, r, 0.98) +} diff --git a/common/ann/bruteforce.go b/common/ann/bruteforce.go index d0c0a069b..1869da9ae 100644 --- a/common/ann/bruteforce.go +++ b/common/ann/bruteforce.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package search +package ann import ( "github.com/juju/errors" @@ -23,7 +23,6 @@ import ( // Bruteforce is a naive implementation of vector index. 
type Bruteforce[T any] struct { distanceFunc func(a, b []T) float32 - dimension int vectors [][]T } @@ -32,15 +31,9 @@ func NewBruteforce[T any](distanceFunc func(a, b []T) float32) *Bruteforce[T] { } func (b *Bruteforce[T]) Add(v []T) (int, error) { - // Check dimension - if b.dimension == 0 { - b.dimension = len(v) - } else if b.dimension != len(v) { - return 0, errors.Errorf("dimension mismatch: %v != %v", b.dimension, len(v)) - } // Add vector b.vectors = append(b.vectors, v) - return len(b.vectors) - 1, nil + return len(b.vectors), nil } func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { @@ -62,14 +55,14 @@ func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, flo scores := make([]lo.Tuple2[int, float32], 0) for pq.Len() > 0 { value, score := pq.Pop() - if !prune0 || score < 0 { + if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores, nil } -func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { +func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] { // Search pq := heap.NewPriorityQueue(true) for i, vec := range b.vectors { @@ -82,9 +75,9 @@ func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int scores := make([]lo.Tuple2[int, float32], 0) for pq.Len() > 0 { value, score := pq.Pop() - if !prune0 || score < 0 { + if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } - return scores, nil + return scores } diff --git a/common/ann/hnsw.go b/common/ann/hnsw.go index fe6b74f64..481a17ac2 100644 --- a/common/ann/hnsw.go +++ b/common/ann/hnsw.go @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package search +package ann import ( - "errors" "github.com/chewxy/math32" mapset "github.com/deckarep/golang-set/v2" "github.com/samber/lo" @@ -28,7 +27,6 @@ import ( // HNSW is a vector index based on Hierarchical Navigable Small Worlds. 
type HNSW[T any] struct { distanceFunc func(a, b []T) float32 - dimension int vectors [][]T bottomNeighbors []*heap.PriorityQueue upperNeighbors []map[int32]*heap.PriorityQueue @@ -53,12 +51,6 @@ func NewHNSW[T any](distanceFunc func(a, b []T) float32) *HNSW[T] { } func (h *HNSW[T]) Add(v []T) (int, error) { - // Check dimension - if h.dimension == 0 { - h.dimension = len(v) - } else if h.dimension != len(v) { - return 0, errors.New("dimension mismatch") - } // Add vector h.vectors = append(h.vectors, v) h.bottomNeighbors = append(h.bottomNeighbors, heap.NewPriorityQueue(false)) @@ -66,28 +58,28 @@ func (h *HNSW[T]) Add(v []T) (int, error) { return len(h.vectors) - 1, nil } -func (h *HNSW[T]) Search(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { +func (h *HNSW[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { w := h.knnSearch(h.vectors[q], k, h.efSearchValue(k)) scores := make([]lo.Tuple2[int, float32], 0) for w.Len() > 0 { value, score := w.Pop() - if !prune0 || score < 0 { + if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } return scores, nil } -func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) { +func (h *HNSW[T]) SearchVector(q []T, k int, prune0 bool) []lo.Tuple2[int, float32] { w := h.knnSearch(q, k, h.efSearchValue(k)) scores := make([]lo.Tuple2[int, float32], 0) for w.Len() > 0 { value, score := w.Pop() - if !prune0 || score < 0 { + if !prune0 || score > 0 { scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score}) } } - return scores, nil + return scores } func (h *HNSW[T]) knnSearch(q []T, k, ef int) *heap.PriorityQueue { diff --git a/common/dataset/dataset_test.go b/common/dataset/dataset_test.go deleted file mode 100644 index 78ef60ccd..000000000 --- a/common/dataset/dataset_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package dataset - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestLoadIris(t *testing.T) { - data, target, err := LoadIris() - assert.NoError(t, err) - assert.Len(t, data, 150) - assert.Len(t, data[0], 4) - assert.Len(t, target, 150) -} diff --git a/common/dataset/dataset.go b/common/datautil/datautil.go similarity index 98% rename from common/dataset/dataset.go rename to common/datautil/datautil.go index 3fc83bdda..e92e74cec 100644 --- a/common/dataset/dataset.go +++ b/common/datautil/datautil.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package dataset +package datautil import ( "archive/zip" @@ -85,7 +85,7 @@ func DownloadAndUnzip(name string) (string, error) { path := filepath.Join(datasetDir, name) if _, err := os.Stat(path); os.IsNotExist(err) { zipFileName, _ := downloadFromUrl(url, tempDir) - if _, err := unzip(zipFileName, path); err != nil { + if _, err := unzip(zipFileName, datasetDir); err != nil { return "", err } } diff --git a/common/datautil/datautil_test.go b/common/datautil/datautil_test.go new file mode 100644 index 000000000..f4a61dd2f --- /dev/null +++ b/common/datautil/datautil_test.go @@ -0,0 +1,28 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datautil + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLoadIris(t *testing.T) { + data, target, err := LoadIris() + assert.NoError(t, err) + assert.Len(t, data, 150) + assert.Len(t, data[0], 4) + assert.Len(t, target, 150) +} diff --git a/common/nn/nn_test.go b/common/nn/nn_test.go index 917108860..2e7220a28 100644 --- a/common/nn/nn_test.go +++ b/common/nn/nn_test.go @@ -31,7 +31,7 @@ import ( "github.com/samber/lo" "github.com/schollz/progressbar/v3" "github.com/stretchr/testify/assert" - "github.com/zhenghaoz/gorse/common/dataset" + "github.com/zhenghaoz/gorse/common/datautil" "github.com/zhenghaoz/gorse/common/util" ) @@ -91,7 +91,7 @@ func TestNeuralNetwork(t *testing.T) { func iris() (*Tensor, *Tensor, error) { // Download dataset - path, err := dataset.DownloadAndUnzip("iris") + path, err := datautil.DownloadAndUnzip("iris") if err != nil { return nil, nil, err } @@ -153,7 +153,7 @@ func TestIris(t *testing.T) { func mnist() (lo.Tuple2[*Tensor, *Tensor], lo.Tuple2[*Tensor, *Tensor], error) { var train, test lo.Tuple2[*Tensor, *Tensor] // Download and unzip dataset - path, err := dataset.DownloadAndUnzip("mnist") + path, err := datautil.DownloadAndUnzip("mnist") if err != nil { return train, test, err } diff --git a/config/config.go b/config/config.go index 2145c6d29..e9f990a35 100644 --- a/config/config.go +++ b/config/config.go @@ -116,6 +116,7 @@ type RecommendConfig struct { DataSource DataSourceConfig `mapstructure:"data_source"` NonPersonalized []NonPersonalizedConfig `mapstructure:"non-personalized" validate:"dive"` Popular PopularConfig `mapstructure:"popular"` + ItemToItem []ItemToItemConfig `mapstructure:"item-to-item" validate:"dive"` UserNeighbors NeighborsConfig `mapstructure:"user_neighbors"` ItemNeighbors NeighborsConfig `mapstructure:"item_neighbors"` Collaborative CollaborativeConfig `mapstructure:"collaborative"` @@ -148,6 +149,20 @@ type NeighborsConfig struct { IndexFitEpoch int `mapstructure:"index_fit_epoch" validate:"gt=0"` } +type ItemToItemConfig struct { + Name string `mapstructure:"name" json:"name"` + Type string `mapstructure:"type" json:"type" validate:"oneof=embedding"` + Column string `mapstructure:"column" json:"column" validate:"item_expr"` +} + +func (config *ItemToItemConfig) Hash() string { + hash := md5.New() + hash.Write([]byte(config.Name)) + hash.Write([]byte(config.Type)) + hash.Write([]byte(config.Column)) + return string(hash.Sum(nil)) +} + type CollaborativeConfig struct { ModelFitPeriod time.Duration `mapstructure:"model_fit_period" validate:"gt=0"` ModelSearchPeriod time.Duration `mapstructure:"model_search_period" validate:"gt=0"` diff --git a/config/config_test.go b/config/config_test.go index d525275bd..01ef96c6d 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -492,3 +492,21 @@ func TestConfig_OfflineRecommendDigest(t *testing.T) { cfg2.Recommend.Replacement.PositiveReplacementDecay = 0.2 assert.Equal(t, cfg1.OfflineRecommendDigest(), cfg2.OfflineRecommendDigest()) } + +func TestItemToItemConfig_Hash(t *testing.T) { + a := ItemToItemConfig{} + b := 
ItemToItemConfig{} + assert.Equal(t, a.Hash(), b.Hash()) + + a = ItemToItemConfig{Name: "a"} + b = ItemToItemConfig{Name: "b"} + assert.NotEqual(t, a.Hash(), b.Hash()) + + a = ItemToItemConfig{Type: "a"} + b = ItemToItemConfig{Type: "b"} + assert.NotEqual(t, a.Hash(), b.Hash()) + + a = ItemToItemConfig{Column: "a"} + b = ItemToItemConfig{Column: "b"} + assert.NotEqual(t, a.Hash(), b.Hash()) +} diff --git a/dataset/dataset.go b/dataset/dataset.go new file mode 100644 index 000000000..55fb225c2 --- /dev/null +++ b/dataset/dataset.go @@ -0,0 +1,81 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dataset + +import ( + "github.com/samber/lo" + "github.com/zhenghaoz/gorse/storage/data" + "time" +) + +type Dataset struct { + timestamp time.Time + items []data.Item +} + +func NewDataset(timestamp time.Time, itemCount int) *Dataset { + return &Dataset{ + timestamp: timestamp, + items: make([]data.Item, 0, itemCount), + } +} + +func (d *Dataset) GetTimestamp() time.Time { + return d.timestamp +} + +func (d *Dataset) GetItems() []data.Item { + return d.items +} + +func (d *Dataset) AddItem(item data.Item) { + d.items = append(d.items, data.Item{ + ItemId: item.ItemId, + IsHidden: item.IsHidden, + Categories: item.Categories, + Timestamp: item.Timestamp, + Labels: d.processLabels(item.Labels), + Comment: item.Comment, + }) +} + +func (d *Dataset) processLabels(labels any) any { + switch typed := labels.(type) { + case map[string]any: + o := make(map[string]any) + for k, v := range typed { + o[k] = d.processLabels(v) + } + return o + case []any: + if isSliceOf[float64](typed) { + return lo.Map(typed, func(e any, _ int) float32 { + return float32(e.(float64)) + }) + } + return typed + default: + return labels + } +} + +func isSliceOf[T any](v []any) bool { + for _, e := range v { + if _, ok := e.(T); !ok { + return false + } + } + return true +} diff --git a/dataset/dataset_test.go b/dataset/dataset_test.go new file mode 100644 index 000000000..23bfbdaa5 --- /dev/null +++ b/dataset/dataset_test.go @@ -0,0 +1,49 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dataset + +import ( + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/storage/data" + "testing" + "time" +) + +func TestDataset_AddItem(t *testing.T) { + dataSet := NewDataset(time.Now(), 1) + dataSet.AddItem(data.Item{ + ItemId: "1", + IsHidden: false, + Categories: []string{"a", "b"}, + Timestamp: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + Labels: map[string]any{ + "a": 1, + "embedded": []any{1.1, 2.2, 3.3}, + }, + Comment: "comment", + }) + assert.Len(t, dataSet.GetItems(), 1) + assert.Equal(t, data.Item{ + ItemId: "1", + IsHidden: false, + Categories: []string{"a", "b"}, + Timestamp: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + Labels: map[string]any{ + "a": 1, + "embedded": []float32{1.1, 2.2, 3.3}, + }, + Comment: "comment", + }, dataSet.GetItems()[0]) +} diff --git a/go.mod b/go.mod index f392cc08d..dac11763d 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606 + github.com/gorse-io/dashboard v0.0.0-20250101053324-8d40fd3b3a1c github.com/gorse-io/gorse-go v0.5.0-alpha.1 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 diff --git a/go.sum b/go.sum index 38b41b338..91ddb90b9 100644 --- a/go.sum +++ b/go.sum @@ -314,8 +314,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606 h1:5Vh8xik8c905IYFg66ujt7FuuuPtzSW6e2DRzBUYc58= -github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc= +github.com/gorse-io/dashboard v0.0.0-20250101053324-8d40fd3b3a1c h1:2G3W2QefCQqnAz6UiKFd0SQM5aG8KPPzJT1eakODFMY= +github.com/gorse-io/dashboard v0.0.0-20250101053324-8d40fd3b3a1c/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/gorse-go v0.5.0-alpha.1 h1:QBWKGAbSKNAWnieXVIdQiE0lLGvKXfFFAFPOQEkPW/E= diff --git a/logics/item_to_item.go b/logics/item_to_item.go new file mode 100644 index 000000000..6baccbc7f --- /dev/null +++ b/logics/item_to_item.go @@ -0,0 +1,109 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package logics + +import ( + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/vm" + "github.com/samber/lo" + "github.com/zhenghaoz/gorse/base/floats" + "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/common/ann" + "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/storage/cache" + "github.com/zhenghaoz/gorse/storage/data" + "go.uber.org/zap" + "time" +) + +type ItemToItem struct { + name string + n int + timestamp time.Time + columnFunc *vm.Program + index *ann.HNSW[float32] + items []string + dimension int +} + +func NewItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (*ItemToItem, error) { + // Compile column expression + columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ + "item": data.Item{}, + })) + if err != nil { + return nil, err + } + return &ItemToItem{ + name: cfg.Name, + n: n, + timestamp: timestamp, + columnFunc: columnFunc, + index: ann.NewHNSW[float32](floats.Euclidean), + }, nil +} + +func (i *ItemToItem) Push(item data.Item) { + // Check if hidden + if item.IsHidden { + return + } + // Evaluate filter function + result, err := expr.Run(i.columnFunc, map[string]any{ + "item": item, + }) + if err != nil { + log.Logger().Error("failed to evaluate column expression", + zap.Any("item", item), zap.Error(err)) + return + } + // Check column type + v, ok := result.([]float32) + if !ok { + log.Logger().Error("invalid column type", zap.Any("column", result)) + return + } + // Check dimension + if i.dimension == 0 && len(v) > 0 { + i.dimension = len(v) + } else if i.dimension != len(v) { + log.Logger().Error("invalid column dimension", zap.Int("dimension", len(v))) + return + } + // Push item + i.items = append(i.items, item.ItemId) + _, err = i.index.Add(v) + if err != nil { + log.Logger().Error("failed to add item to index", zap.Error(err)) + return + } +} + +func (i *ItemToItem) PopAll(callback func(itemId string, score []cache.Score)) { + for index, item := range i.items { + scores, err := i.index.SearchIndex(index, i.n+1, true) + if err != nil { + log.Logger().Error("failed to search index", zap.Error(err)) + return + } + callback(item, lo.Map(scores, func(v lo.Tuple2[int, float32], _ int) cache.Score { + return cache.Score{ + Id: i.items[v.A], + Score: float64(v.B), + Timestamp: i.timestamp, + } + })) + } +} diff --git a/logics/item_to_item_test.go b/logics/item_to_item_test.go new file mode 100644 index 000000000..f2791aae7 --- /dev/null +++ b/logics/item_to_item_test.go @@ -0,0 +1,104 @@ +// Copyright 2024 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package logics + +import ( + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/storage/cache" + "github.com/zhenghaoz/gorse/storage/data" + "strconv" + "testing" + "time" +) + +func TestColumnFunc(t *testing.T) { + item2item, err := NewItemToItem(config.ItemToItemConfig{ + Column: "item.Labels.description", + }, 10, time.Now()) + assert.NoError(t, err) + + // Push success + item2item.Push(data.Item{ + ItemId: "1", + Labels: map[string]any{ + "description": []float32{0.1, 0.2, 0.3}, + }, + }) + assert.Len(t, item2item.items, 1) + + // Hidden + item2item.Push(data.Item{ + ItemId: "2", + IsHidden: true, + Labels: map[string]any{ + "description": []float32{0.1, 0.2, 0.3}, + }, + }) + assert.Len(t, item2item.items, 1) + + // Dimension does not match + item2item.Push(data.Item{ + ItemId: "1", + Labels: map[string]any{ + "description": []float32{0.1, 0.2}, + }, + }) + assert.Len(t, item2item.items, 1) + + // Type does not match + item2item.Push(data.Item{ + ItemId: "1", + Labels: map[string]any{ + "description": "hello", + }, + }) + assert.Len(t, item2item.items, 1) + + // Column does not exist + item2item.Push(data.Item{ + ItemId: "2", + Labels: []float32{0.1, 0.2, 0.3}, + }) + assert.Len(t, item2item.items, 1) +} + +func TestEmbedding(t *testing.T) { + timestamp := time.Now() + item2item, err := NewItemToItem(config.ItemToItemConfig{ + Column: "item.Labels.description", + }, 10, timestamp) + assert.NoError(t, err) + + for i := 0; i < 100; i++ { + item2item.Push(data.Item{ + ItemId: strconv.Itoa(i), + Labels: map[string]any{ + "description": []float32{0.1 * float32(i), 0.2 * float32(i), 0.3 * float32(i)}, + }, + }) + } + + var scores []cache.Score + item2item.PopAll(func(itemId string, score []cache.Score) { + if itemId == "0" { + scores = score + } + }) + assert.Len(t, scores, 10) + for i := 1; i <= 10; i++ { + assert.Equal(t, strconv.Itoa(i), scores[i-1].Id) + } +} diff --git a/master/tasks.go b/master/tasks.go index 926786b71..271d12e4a 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -17,7 +17,6 @@ package master import ( "context" "fmt" - "github.com/zhenghaoz/gorse/logics" "sort" "strings" "sync" @@ -37,6 +36,8 @@ import ( "github.com/zhenghaoz/gorse/base/sizeof" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/dataset" + "github.com/zhenghaoz/gorse/logics" "github.com/zhenghaoz/gorse/model/click" "github.com/zhenghaoz/gorse/model/ranking" "github.com/zhenghaoz/gorse/storage/cache" @@ -93,12 +94,13 @@ func (m *Master) runLoadDatasetTask() error { zap.Uint("item_ttl", m.Config.Recommend.DataSource.ItemTTL), zap.Uint("feedback_ttl", m.Config.Recommend.DataSource.PositiveFeedbackTTL)) evaluator := NewOnlineEvaluator() - rankingDataset, clickDataset, err := m.LoadDataFromDatabase(ctx, m.DataClient, + rankingDataset, clickDataset, dataSet, err := m.LoadDataFromDatabase(ctx, m.DataClient, m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes, m.Config.Recommend.DataSource.ItemTTL, m.Config.Recommend.DataSource.PositiveFeedbackTTL, - evaluator, nonPersonalizedRecommenders) + evaluator, + nonPersonalizedRecommenders) if err != nil { return errors.Trace(err) } @@ -201,6 +203,10 @@ func (m *Master) runLoadDatasetTask() error { MemoryInUseBytesVec.WithLabelValues("ranking_train_set").Set(float64(sizeof.DeepSize(m.clickTrainSet))) MemoryInUseBytesVec.WithLabelValues("ranking_test_set").Set(float64(sizeof.DeepSize(m.clickTestSet))) + if 
err = m.updateItemToItem(dataSet); err != nil { + log.Logger().Error("failed to update item-to-item recommendation", zap.Error(err)) + } + LoadDatasetTotalSeconds.Set(time.Since(initialStartTime).Seconds()) return nil } @@ -1407,21 +1413,23 @@ func (m *Master) LoadDataFromDatabase( itemTTL, positiveFeedbackTTL uint, evaluator *OnlineEvaluator, nonPersonalizedRecommenders []*logics.NonPersonalized, -) (rankingDataset *ranking.DataSet, clickDataset *click.Dataset, err error) { +) (rankingDataset *ranking.DataSet, clickDataset *click.Dataset, dataSet *dataset.Dataset, err error) { // Estimate the number of users, items, and feedbacks estimatedNumUsers, err := m.DataClient.CountUsers(context.Background()) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } estimatedNumItems, err := m.DataClient.CountItems(context.Background()) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } estimatedNumFeedbacks, err := m.DataClient.CountFeedback(context.Background()) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } + dataSet = dataset.NewDataset(time.Now(), estimatedNumItems) + newCtx, span := progress.Start(ctx, "LoadDataFromDatabase", estimatedNumUsers+estimatedNumItems+estimatedNumFeedbacks) defer span.End() @@ -1486,7 +1494,7 @@ func (m *Master) LoadDataFromDatabase( span.Add(len(users)) } if err = <-errChan; err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } rankingDataset.NumUserLabels = userLabelIndex.Len() log.Logger().Debug("pulled users from database", @@ -1542,11 +1550,12 @@ func (m *Master) LoadDataFromDatabase( if item.IsHidden { // set hidden flag rankingDataset.HiddenItems[itemIndex] = true } + dataSet.AddItem(item) } span.Add(len(batchItems)) } if err = <-errChan; err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } rankingDataset.NumItemLabels = itemLabelIndex.Len() log.Logger().Debug("pulled items from database", @@ -1651,7 +1660,7 @@ func (m *Master) LoadDataFromDatabase( return nil }) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } log.Logger().Debug("pulled positive feedback from database", zap.Int("n_positive_feedback", posFeedbackCount), @@ -1701,7 +1710,7 @@ func (m *Master) LoadDataFromDatabase( return nil }) if err != nil { - return nil, nil, errors.Trace(err) + return nil, nil, nil, errors.Trace(err) } log.Logger().Debug("pulled negative feedback from database", zap.Int("n_negative_feedback", int(negativeFeedbackCount)), @@ -1750,5 +1759,101 @@ func (m *Master) LoadDataFromDatabase( zap.Int("n_valid_negative", clickDataset.NegativeCount), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("create_ranking_dataset").Set(time.Since(start).Seconds()) - return rankingDataset, clickDataset, nil + return rankingDataset, clickDataset, dataSet, nil +} + +func (m *Master) updateItemToItem(dataset *dataset.Dataset) error { + ctx, span := m.tracer.Start(context.Background(), "Generate item-to-item recommendation", + len(dataset.GetItems())*len(m.Config.Recommend.ItemToItem)*2) + defer span.End() + + // Build item-to-item recommenders + itemToItemRecommenders := make([]*logics.ItemToItem, 0, len(m.Config.Recommend.ItemToItem)) + for _, cfg := range m.Config.Recommend.ItemToItem { + recommender, err := logics.NewItemToItem(cfg, m.Config.Recommend.CacheSize, dataset.GetTimestamp()) + if err 
!= nil { + return errors.Trace(err) + } + itemToItemRecommenders = append(itemToItemRecommenders, recommender) + } + + // Push items to item-to-item recommenders + for _, item := range dataset.GetItems() { + if !item.IsHidden { + for _, recommender := range itemToItemRecommenders { + recommender.Push(item) + span.Add(1) + } + } + } + + // Save item-to-item recommendations to cache + for i, recommender := range itemToItemRecommenders { + recommender.PopAll(func(itemId string, score []cache.Score) { + itemToItemConfig := m.Config.Recommend.ItemToItem[i] + if m.needUpdateItemToItem(itemId, m.Config.Recommend.ItemToItem[i]) { + log.Logger().Debug("update item-to-item recommendation", + zap.String("item_id", itemId), + zap.String("name", itemToItemConfig.Name), + zap.Int("n_recommendations", len(score))) + // Save item-to-item recommendation to cache + if err := m.CacheClient.AddScores(ctx, cache.ItemToItem, cache.Key(itemToItemConfig.Name, itemId), score); err != nil { + log.Logger().Error("failed to save item-to-item recommendation to cache", + zap.String("item_id", itemId), zap.Error(err)) + return + } + // Save item-to-item digest and last update time to cache + if err := m.CacheClient.Set(ctx, + cache.String(cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, itemId), itemToItemConfig.Hash()), + cache.Time(cache.Key(cache.LastUpdateItemToItemTime, itemToItemConfig.Name, itemId), time.Now()), + ); err != nil { + log.Logger().Error("failed to save item-to-item digest to cache", + zap.String("item_id", itemId), zap.Error(err)) + return + } + } + span.Add(1) + }) + } + return nil +} + +// needUpdateItemToItem checks if item-to-item recommendation needs to be updated. +// 1. The cache is empty. +// 2. The modified time is newer than the last update time. 
+func (m *Master) needUpdateItemToItem(itemId string, itemToItemConfig config.ItemToItemConfig) bool { + ctx := context.Background() + + // check cache + items, err := m.CacheClient.SearchScores(ctx, cache.ItemToItem, + cache.Key(itemToItemConfig.Name, itemId), nil, 0, -1) + if err != nil { + log.Logger().Error("failed to fetch item-to-item recommendation", + zap.String("item_id", itemId), zap.Error(err)) + return true + } else if len(items) == 0 { + return true + } + + // check digest + digest, err := m.CacheClient.Get(ctx, cache.Key(cache.ItemToItemDigest, itemToItemConfig.Name, itemId)).String() + if err != nil { + if !errors.Is(err, errors.NotFound) { + log.Logger().Error("failed to read item-to-item digest", zap.Error(err)) + } + return true + } + if digest != itemToItemConfig.Hash() { + return true + } + + // check update time + updateTime, err := m.CacheClient.Get(ctx, cache.Key(cache.LastUpdateItemToItemTime, itemToItemConfig.Name, itemId)).Time() + if err != nil { + if !errors.Is(err, errors.NotFound) { + log.Logger().Error("failed to read last update item neighbors time", zap.Error(err)) + } + return true + } + return updateTime.Before(time.Now().Add(-m.Config.Recommend.CacheExpire)) } diff --git a/master/tasks_test.go b/master/tasks_test.go index 51e831d73..1db5d4d7f 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -81,7 +81,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { } // load mock dataset - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset @@ -187,7 +187,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { } // load mock dataset - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset @@ -255,7 +255,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "1"}}, }, true, true, true) s.NoError(err) - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset @@ -316,7 +316,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsBruteForce() { s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset @@ -397,7 +397,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF() { s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, 
NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset @@ -457,7 +457,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "1", ItemId: "0"}}, }, true, true, true) s.NoError(err) - dataset, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, + dataset, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator(), nil) s.NoError(err) s.rankingTrainSet = dataset diff --git a/server/rest.go b/server/rest.go index 45a5eaca3..5f170f9eb 100644 --- a/server/rest.go +++ b/server/rest.go @@ -484,6 +484,17 @@ func (s *RestServer) CreateWebService() { Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). Returns(http.StatusOK, "OK", []cache.Score{}). Writes([]cache.Score{})) + // Get item-to-item recommendation + ws.Route(ws.GET("/item-to-item/{name}/{item-id}").To(s.getItemToItem). + Doc("Get item-to-item recommendation."). + Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). + Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). + Param(ws.PathParameter("name", "Name of the item-to-item recommendation").DataType("string")). + Param(ws.PathParameter("item-id", "ID of the item to get neighbors").DataType("string")). + Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). + Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). + Returns(http.StatusOK, "OK", []cache.Score{}). + Writes([]cache.Score{})) // Get neighbors ws.Route(ws.GET("/item/{item-id}/neighbors/").To(s.getItemNeighbors). Doc("Get neighbors of a item"). @@ -684,6 +695,12 @@ func (s *RestServer) getNonPersonalized(request *restful.Request, response *rest s.SearchDocuments(cache.NonPersonalized, name, categories, nil, request, response) } +func (s *RestServer) getItemToItem(request *restful.Request, response *restful.Response) { + name := request.PathParameter("name") + itemId := request.PathParameter("item-id") + s.SearchDocuments(cache.ItemToItem, cache.Key(name, itemId), nil, nil, request, response) +} + // get feedback by item-id with feedback type func (s *RestServer) getTypedFeedbackByItem(request *restful.Request, response *restful.Response) { ctx := context.Background() @@ -1979,6 +1996,7 @@ func withWildCard(categories []string) []string { return result } +// ReadCategories tries to read categories from the request. If the category is not found, it returns an empty string. 
func ReadCategories(request *restful.Request) []string { if pathValue := request.PathParameter("category"); pathValue != "" { return []string{pathValue} diff --git a/server/rest_test.go b/server/rest_test.go index cdfa8f588..ed75f32d6 100644 --- a/server/rest_test.go +++ b/server/rest_test.go @@ -828,6 +828,7 @@ func (suite *ServerTestSuite) TestNonPersonalizedRecommend() { {"PopularItemsCategory", cache.NonPersonalized, cache.Popular, "0", "/api/popular/0"}, {"NonPersonalized", cache.NonPersonalized, "trending", "", "/api/non-personalized/trending"}, {"NonPersonalizedCategory", cache.NonPersonalized, "trending", "0", "/api/non-personalized/trending"}, + {"ItemToItem", cache.ItemToItem, cache.Key("lookalike", "0"), "", "/api/item-to-item/lookalike/0"}, {"Offline Recommend", cache.OfflineRecommend, "0", "", "/api/intermediate/recommend/0"}, {"Offline Recommend in Category", cache.OfflineRecommend, "0", "0", "/api/intermediate/recommend/0/0"}, } diff --git a/storage/cache/database.go b/storage/cache/database.go index 999ee68b1..451c22188 100644 --- a/storage/cache/database.go +++ b/storage/cache/database.go @@ -77,9 +77,11 @@ const ( // Recommendation digest - offline_recommend_digest/{user_id} OfflineRecommendDigest = "offline_recommend_digest" - NonPersonalized = "non-personalized" - Latest = "latest" - Popular = "popular" + NonPersonalized = "non-personalized" + Latest = "latest" + Popular = "popular" + ItemToItem = "item-to-item" + ItemToItemDigest = "item-to-item_digest" // ItemCategories is the set of item categories. The format of key: // Global item categories - item_categories @@ -90,6 +92,7 @@ const ( LastUpdateUserRecommendTime = "last_update_user_recommend_time" // the latest timestamp that a user's recommendation was updated LastUpdateUserNeighborsTime = "last_update_user_neighbors_time" // the latest timestamp that a user's neighbors item was updated LastUpdateItemNeighborsTime = "last_update_item_neighbors_time" // the latest timestamp that an item's neighbors was updated + LastUpdateItemToItemTime = "last_update_item_to_item_time" // the latest timestamp that an item-to-item model was updated // GlobalMeta is global meta information GlobalMeta = "global_meta" @@ -110,7 +113,12 @@ const ( MatchingIndexRecall = "matching_index_recall" ) -var ItemCache = []string{NonPersonalized, ItemNeighbors, OfflineRecommend} +var ItemCache = []string{ + NonPersonalized, + ItemToItem, + ItemNeighbors, + OfflineRecommend, +} var ( ErrObjectNotExist = errors.NotFoundf("object")
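
Reviewer note: the item-to-item pipeline added by this patch is configured through [[recommend.item-to-item]] entries (fields name, type = "embedding", and column, an item expression such as item.Labels.embedding), is rebuilt by Master.updateItemToItem during the load-dataset task, is stored under the cache.ItemToItem collection with key Key(name, itemId), and is served at GET /api/item-to-item/{name}/{item-id}. The sketch below is a minimal, self-contained illustration of how the new logics.ItemToItem type is driven, adapted from TestEmbedding in logics/item_to_item_test.go; the recommender name "similar" and the label key "embedding" are illustrative assumptions, not values required by the patch.

// A minimal sketch of the item-to-item flow introduced in this patch,
// adapted from TestEmbedding in logics/item_to_item_test.go. The name
// "similar" and the label key "embedding" are illustrative only.
package main

import (
	"fmt"
	"time"

	"github.com/zhenghaoz/gorse/config"
	"github.com/zhenghaoz/gorse/logics"
	"github.com/zhenghaoz/gorse/storage/cache"
	"github.com/zhenghaoz/gorse/storage/data"
)

func main() {
	// Build a recommender that reads an embedding vector from
	// item.Labels.embedding and keeps the 10 nearest neighbors per item.
	recommender, err := logics.NewItemToItem(config.ItemToItemConfig{
		Name:   "similar",
		Type:   "embedding",
		Column: "item.Labels.embedding",
	}, 10, time.Now())
	if err != nil {
		panic(err)
	}
	// Push items; Push skips hidden items and items whose column is
	// missing, of the wrong type, or of a mismatched dimension.
	for i, v := range [][]float32{{0.1, 0.2, 0.3}, {0.1, 0.2, 0.4}, {0.9, 0.8, 0.7}} {
		recommender.Push(data.Item{
			ItemId: fmt.Sprint(i),
			Labels: map[string]any{"embedding": v},
		})
	}
	// Collect the neighbors of every pushed item; in the master these
	// scores are written to cache.ItemToItem keyed by (name, item id).
	recommender.PopAll(func(itemId string, scores []cache.Score) {
		fmt.Println(itemId, scores)
	})
}

In the master task, the same PopAll callback additionally checks needUpdateItemToItem (empty cache, stale digest, or expired last-update time) before writing scores, the digest, and the update timestamp to the cache client.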