diff --git a/pkg/entities/offchain_tip_bot.go b/pkg/entities/offchain_tip_bot.go index 829b5dfd..ec8f610b 100644 --- a/pkg/entities/offchain_tip_bot.go +++ b/pkg/entities/offchain_tip_bot.go @@ -18,6 +18,7 @@ import ( "github.com/defipod/mochi/pkg/response" "github.com/defipod/mochi/pkg/service/mochipay" "github.com/defipod/mochi/pkg/util" + sliceutils "github.com/defipod/mochi/pkg/util/slice" ) func (e *Entity) TransferToken(req request.OffchainTransferRequest) (*response.OffchainTipBotTransferToken, error) { @@ -170,29 +171,47 @@ func (e *Entity) sendLogNotify(req request.OffchainTransferRequest, decimal int, } } +func (e *Entity) getTransferToken(req request.TransferV2Request) (token *mochipay.Token, err error) { + logger := e.log.Fields(logger.Fields{"component": "entity.getTransferToken", "req": req}) + + if req.TokenId != "" { + token, err = e.svc.MochiPay.GetTokenById(req.TokenId) + } else { + token, err = e.svc.MochiPay.GetToken(req.Token, req.ChainID) + } + + if err != nil { + logger.Error(err, "failed to get token") + return nil, err + } + + if token == nil { + err = errors.New("token not found") + logger.Error(err, "token not found") + return nil, err + } + + return +} + func (e *Entity) TransferTokenV2(req request.TransferV2Request) (*response.TransferTokenV2Data, error) { - logger := e.log.Fields(logger.Fields{"component": "entity.TransferV2", "req": req}) + logger := e.log.Fields(logger.Fields{"component": "entity.TransferTokenV2", "req": req}) logger.Info("receive new transfer request") template := parseTemplate(req) // validate token - token, err := e.svc.MochiPay.GetToken(req.Token, req.ChainID) + token, err := e.getTransferToken(req) if err != nil { - logger.Error(err, "[entity.TransferTokenV2] svc.MochiPay.GetToken() failed") + logger.Error(err, "getTransferToken() failed") return nil, err } - if token == nil { - logger.Error(err, "[entity.TransferTokenV2] token not found") - return nil, errors.New(consts.OffchainTipBotFailReasonTokenNotSupported) - } - // convert total transfer amount totalAmount := util.FloatToBigInt(req.Amount, token.Decimal) // validate balance if err := e.validateTransferBalance(totalAmount, req); err != nil { - logger.Error(err, "[entity.TransferTokenV2] svc.MochiPay.GetToken() failed") + logger.Error(err, "validateTransferBalance() failed") return nil, err } @@ -329,7 +348,7 @@ func (e *Entity) validateTransferBalance(total *big.Int, req request.TransferV2R // validate balance senderBalance, err := e.svc.MochiPay.GetBalance(req.Sender, req.Token, req.ChainID) if err != nil { - e.log.Fields(logger.Fields{"token": req.Token, "user": req.Sender}).Error(err, "[entity.TransferTokenV2] repo.OffchainTipBotUserBalances.GetUserBalanceByTokenID() failed") + e.log.Fields(logger.Fields{"token": req.Token, "user": req.Sender}).Error(err, "[entity.TransferTokenV2] svc.MochiPay.GetBalance() failed") return err } @@ -337,12 +356,21 @@ func (e *Entity) validateTransferBalance(total *big.Int, req request.TransferV2R return errors.New(consts.OffchainTipBotFailReasonNotEnoughBalance) } - bal, err := util.StringToBigInt(senderBalance.Data[0].Amount) + var bal *mochipay.GetBalanceResponse + if req.TokenId != "" { + bal = sliceutils.Find(senderBalance.Data, func(b mochipay.GetBalanceResponse) bool { + return b.TokenId == req.TokenId + }) + } else { + bal = &senderBalance.Data[0] + } + + currentAmount, err := util.StringToBigInt(bal.Amount) if err != nil { return errors.New(consts.OffchainTipBotFailReasonInvalidAmount) } - if bal.Cmp(total) < 0 { + if currentAmount.Cmp(total) < 0 { return errors.New(consts.OffchainTipBotFailReasonNotEnoughBalance) } diff --git a/pkg/request/offchain_tip_bot.go b/pkg/request/offchain_tip_bot.go index 6e121c8b..5f325b5e 100644 --- a/pkg/request/offchain_tip_bot.go +++ b/pkg/request/offchain_tip_bot.go @@ -31,6 +31,7 @@ type TransferV2Request struct { GuildID string `json:"guild_id"` Amount float64 `json:"amount" binding:"required"` Token string `json:"token" binding:"required"` + TokenId string `json:"token_id"` Each bool `json:"each"` All bool `json:"all"` TransferType string `json:"transfer_type" binding:"required" enums:"transfer,airdrop"` diff --git a/pkg/service/mochipay/mochipay.go b/pkg/service/mochipay/mochipay.go index dca89252..415210a7 100644 --- a/pkg/service/mochipay/mochipay.go +++ b/pkg/service/mochipay/mochipay.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" @@ -437,7 +438,7 @@ func (m *MochiPay) GetToken(symbol, chainId string) (*Token, error) { return nil, err } - responseBody, err := ioutil.ReadAll(response.Body) + responseBody, err := io.ReadAll(response.Body) if err != nil { return nil, err } @@ -611,3 +612,20 @@ func (m *MochiPay) GetProfileKrystalEarnBalances(profileID string) (any, error) return nil, nil } + +func (m *MochiPay) GetTokenById(tokenId string) (*Token, error) { + var res GetTokenResponse + status, err := util.SendRequest(util.SendRequestQuery{ + URL: fmt.Sprintf("%s/api/v1/tokens/%s", m.config.MochiPayServerHost, tokenId), + Method: "GET", + Response: &res, + }) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("transfer failed with status %d", status) + } + + return res.Data, nil +} diff --git a/pkg/service/mochipay/service.go b/pkg/service/mochipay/service.go index af0b70a2..0c3b7956 100644 --- a/pkg/service/mochipay/service.go +++ b/pkg/service/mochipay/service.go @@ -19,6 +19,7 @@ type Service interface { GetProfileCustodialWallets(profileID string) (any, error) GetProfileKrystalEarnBalances(profileID string) (any, error) GetStakingTokenMapping(symbol, address string) (*StakingTokenMappingResponse, error) + GetTokenById(tokenId string) (*Token, error) // TransferV2 TransferV2(req TransferV2Request) (*TransferV2Response, error) diff --git a/pkg/util/slice/slice.go b/pkg/util/slice/slice.go index b5a0c02c..7e2aba0e 100644 --- a/pkg/util/slice/slice.go +++ b/pkg/util/slice/slice.go @@ -83,3 +83,14 @@ func Equal[T constraints.Ordered](a []T, b []T) bool { }) return reflect.DeepEqual(a, b) } + +// returns the first element of the given slice, which pass the provided callback function ('callbackFn' returns true) +func Find[T any](s []T, callbackFn func(elem T) bool) *T { + for _, item := range s { + if callbackFn(item) { + return &item + } + } + + return nil +}