diff --git a/README.md b/README.md index 85e2272..53bee29 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ # go-YouTokenToMe -go-YouTokenToMe is a Go port of [YoutTokenToMe](https://github.com/VKCOM/YouTokenToMe) - a computationally efficient implementation of Byte Pair Encoding [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162/)]. Only inference is supported, no training. \ No newline at end of file +go-YouTokenToMe is a Go port of [YouTokenToMe](https://github.com/VKCOM/YouTokenToMe) - a computationally efficient implementation of Byte Pair Encoding [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162/)]. Only inference is supported, no training. + +## Usage example +```go +file, err := os.Open("data/yttm.model") +if err != nil { + fmt.Println(err) + return +} +defer file.Close() + +r := io.Reader(file) + +m, err := bpe.ReadModel(r) +if err != nil { + panic(err) +} +config := bpe.NewConfig(false, false, false) +fmt.Println(m.EncodeSentence("мама мыла раму", *config)) +``` \ No newline at end of file diff --git a/bpe.go b/bpe.go index 74f5695..b1a9877 100644 --- a/bpe.go +++ b/bpe.go @@ -102,16 +102,38 @@ func (s specialTokens) toBinary() []byte { return bytesArray } -func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) { +func rowToSpecialTokens(row string) (specialTokens, error) { var s specialTokens - if len(bytesArray) < 16 { - logrus.Error("Bytes array length is too small") - return s, errors.New("bytes array is too small") + rowSplitted := strings.Fields(row) + if len(rowSplitted) != 4 { + logrus.Errorf("String slice with len %d != 4", len(rowSplitted)) + return s, errors.New("string slice is wrong") } - s.unk = int32(binary.BigEndian.Uint32(bytesArray)) - s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:])) - s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:])) - s.eos = 
int32(binary.BigEndian.Uint32(bytesArray[12:])) + unk, err := strconv.Atoi(rowSplitted[0]) + if err != nil { + logrus.Error("Broken input:", err) + return s, err + } + pad, err := strconv.Atoi(rowSplitted[1]) + if err != nil { + logrus.Error("Broken input:", err) + return s, err + } + bos, err := strconv.Atoi(rowSplitted[2]) + if err != nil { + logrus.Error("Broken input:", err) + return s, err + } + eos, err := strconv.Atoi(rowSplitted[3]) + if err != nil { + logrus.Error("Broken input:", err) + return s, err + } + + s.unk = int32(unk) + s.pad = int32(pad) + s.bos = int32(bos) + s.eos = int32(eos) return s, nil } @@ -123,100 +145,138 @@ func (r rule) toBinary() []byte { return bytesArray } -func binaryToRule(bytesArray []byte) (rule, error) { +func rowToRule(row string) (rule, error) { + rowSplitted := strings.Fields(row) var r rule - if len(bytesArray) < 12 { - logrus.Error("Bytes array length is too small") - return r, errors.New("bytes array is too small") + if len(rowSplitted) != 3 { + logrus.Errorf("String slice with len %d != 3", len(rowSplitted)) + return r, errors.New("string slice is wrong") + } + rLeft, err := strconv.Atoi(rowSplitted[0]) + if err != nil { + logrus.Error("Broken input:", err) + return r, err + } + rRight, err := strconv.Atoi(rowSplitted[1]) + if err != nil { + logrus.Error("Broken input:", err) + return r, err + } + rRes, err := strconv.Atoi(rowSplitted[2]) + if err != nil { + logrus.Error("Broken input:", err) + return r, err } - r.left = TokenID(binary.BigEndian.Uint32(bytesArray)) - r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:])) - r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:])) + + r.left = TokenID(rLeft) + r.right = TokenID(rRight) + r.result = TokenID(rRes) return r, nil } // ReadModel loads the BPE model from the binary dump func ReadModel(reader io.Reader) (*Model, error) { - buf := make([]byte, 4) + + scanner := bufio.NewScanner(reader) var nChars, nRules int - if _, err := io.ReadFull(reader, buf); err != 
nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - nChars = int(binary.BigEndian.Uint32(buf)) - if _, err := io.ReadFull(reader, buf); err != nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - nRules = int(binary.BigEndian.Uint32(buf)) + var char rune + var charID TokenID + var row string + var err error - model := newModel(nRules) + model := &Model{} minCharID := TokenID(0) - for i := 0; i < nChars; i++ { - var char rune - var charID TokenID - if _, err := io.ReadFull(reader, buf); err != nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - char = rune(binary.BigEndian.Uint32(buf)) - if _, err := io.ReadFull(reader, buf); err != nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - charID = TokenID(binary.BigEndian.Uint32(buf)) - model.char2id[char] = charID - model.id2char[charID] = char - model.recipe[charID] = EncodedString{charID} - model.revRecipe[string(char)] = charID - if charID < minCharID || minCharID == 0 { - minCharID = charID - model.spaceID = charID - } - } - ruleBuf := make([]byte, 12) - for i := 0; i < nRules; i++ { - if _, err := io.ReadFull(reader, ruleBuf); err != nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - rule, err := binaryToRule(ruleBuf) - if err != nil { - return model, err + + i := 0 + j := 0 + + for scanner.Scan() { + row = scanner.Text() + if i == 0 { + nChars, err = strconv.Atoi(strings.Fields(row)[0]) + if err != nil { + logrus.Error("Broken input:", err) + return &Model{}, err + } + + nRules, err = strconv.Atoi(strings.Fields(row)[1]) + if err != nil { + logrus.Error("Broken input:", err) + return &Model{}, err + } + logrus.Println("Reading bpe model file with number of") + logrus.Println("Characters:", nChars) + logrus.Println("Rules of merge:", nRules) + + model = newModel(nRules) } - if _, ok := model.recipe[rule.left]; !ok { - logrus.Errorf("%d: token id not described before", rule.left) - return model, errors.New("token 
id is impossible") + if i < nChars+1 && i != 0 { + row = scanner.Text() + unicodeChar, err := strconv.Atoi(strings.Fields(row)[0]) + if err != nil { + logrus.Error("Broken input:", err) + return &Model{}, err + } + tokenID, err := strconv.Atoi(strings.Fields(row)[1]) + if err != nil { + logrus.Error("Broken input:", err) + return &Model{}, err + } + + char = rune(unicodeChar) + charID = TokenID(tokenID) + model.char2id[char] = charID + model.id2char[charID] = char + model.recipe[charID] = EncodedString{charID} + model.revRecipe[string(char)] = charID + if charID < minCharID || minCharID == 0 { + minCharID = charID + model.spaceID = charID + } } - if _, ok := model.recipe[rule.right]; !ok { - logrus.Errorf("%d: token id not described before", rule.right) - return model, errors.New("token id is impossible") + if i < nChars+nRules+1 && i >= nChars+1 { + row = scanner.Text() + + rule, err := rowToRule(row) + if err != nil { + return model, err + } + if _, ok := model.recipe[rule.left]; !ok { + logrus.Errorf("%d: token id not described before", rule.left) + return model, errors.New("token id is impossible") + } + if _, ok := model.recipe[rule.right]; !ok { + logrus.Errorf("%d: token id not described before", rule.right) + return model, errors.New("token id is impossible") + } + model.rules[j] = rule + model.rule2id[newTokenIDPair(rule.left, rule.right)] = j + model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...) + resultString, err := DecodeToken(model.recipe[rule.result], model.id2char) + if err != nil { + logrus.Error("Unexpected token id inside the rules: ", err) + return model, err + } + model.revRecipe[resultString] = rule.result + j++ } - model.rules[i] = rule - model.rule2id[newTokenIDPair(rule.left, rule.right)] = i - model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...) 
- resultString, err := DecodeToken(model.recipe[rule.result], model.id2char) - if err != nil { - logrus.Error("Unexpected token id inside the rules: ", err) - return model, err + + if i == nChars+nRules+1 { + row = scanner.Text() + specials, err := rowToSpecialTokens(row) + if err != nil { + return model, err + } + model.specialTokens = specials + model.revRecipe[bosToken] = TokenID(specials.bos) + model.revRecipe[eosToken] = TokenID(specials.eos) + model.revRecipe[unkToken] = TokenID(specials.unk) + model.revRecipe[padToken] = TokenID(specials.pad) } - model.revRecipe[resultString] = rule.result - } - specialTokensBuf := make([]byte, 16) - if _, err := io.ReadFull(reader, specialTokensBuf); err != nil { - logrus.Error("Broken input: ", err) - return &Model{}, err - } - specials, err := binaryToSpecialTokens(specialTokensBuf) - if err != nil { - return model, err + + i++ } - model.specialTokens = specials - model.revRecipe[bosToken] = TokenID(specials.bos) - model.revRecipe[eosToken] = TokenID(specials.eos) - model.revRecipe[unkToken] = TokenID(specials.unk) - model.revRecipe[padToken] = TokenID(specials.pad) - return model, err + return model, nil } // IDToToken returns string token corresponding to the given token id. 
diff --git a/bpe_test.go b/bpe_test.go index 3641c93..a32016e 100644 --- a/bpe_test.go +++ b/bpe_test.go @@ -45,42 +45,12 @@ func TestSpecialTokens_ToBinary(t *testing.T) { require.Equal(t, bytesArray, specials.toBinary()) } -func TestBinaryToSpecialTokens(t *testing.T) { - req := require.New(t) - bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0} - expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127} - specials, err := binaryToSpecialTokens(bytesArray) - req.NoError(err) - req.Equal(expected, specials) - bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0} - specials, err = binaryToSpecialTokens(bytesArray) - req.Error(err) - bytesArray = []byte{} - specials, err = binaryToSpecialTokens(bytesArray) - req.Error(err) -} - func TestRule_ToBinary(t *testing.T) { rule := rule{1, 2, 257} bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} require.Equal(t, bytesArray, rule.toBinary()) } -func TestBinaryToRule(t *testing.T) { - req := require.New(t) - expected := rule{1, 2, 257} - bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} - rule, err := binaryToRule(bytesArray) - req.NoError(err) - req.Equal(expected, rule) - bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1} - rule, err = binaryToRule(bytesArray) - req.Error(err) - bytesArray = []byte{} - rule, err = binaryToRule(bytesArray) - req.Error(err) -} - func TestReadModel(t *testing.T) { req := require.New(t) reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 6, diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e59caa3 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=