From ef81c6af2217650fb545c45301ff53d713d6c2d0 Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Thu, 23 May 2024 02:15:43 +0400 Subject: [PATCH] Add batch mode for provider diagnostic endpoint Signed-off-by: Anton Litvinov --- cmd/bootstrap.go | 96 +++++------ cmd/di.go | 4 - core/connection/manager-diag.go | 1 - core/connection/manager_test.go | 1 - services/wireguard/connection/connection.go | 2 - services/wireguard/endpoint/endpoint.go | 2 - tequilapi/endpoints/connection-diag.go | 171 +++++++++++++++++--- 7 files changed, 202 insertions(+), 75 deletions(-) diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index b0bae35d4e..5024f8fdde 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -45,56 +45,60 @@ func (di *Dependencies) bootstrapTequilapi(nodeOptions node.Options, listener ne } tequilaApiClient := tequilapi_client.NewClient(nodeOptions.TequilapiAddress, nodeOptions.TequilapiPort) + handlers := []func(engine *gin.Engine) error{ + func(e *gin.Engine) error { + if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil { + return err + } + return nil + }, + func(e *gin.Engine) error { + if config.GetBool(config.FlagPProfEnable) { + tequilapi_endpoints.AddRoutesForPProf(e) + } + return nil + }, + func(e *gin.Engine) error { + e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck) + return nil + }, + tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)), + tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), + tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), + tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), + tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), + tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), + tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), + tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient), + tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)), + tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber), + tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)), + tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker), + tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI), + tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator), + tequilapi_endpoints.AddRoutesForConfig, + tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator), + tequilapi_endpoints.AddRoutesForFeedback(di.Reporter), + tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage), + tequilapi_endpoints.AddRoutesForDocs, + tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI), + tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver), + tequilapi_endpoints.AddRoutesForTerms, + tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator( + config.FlagPaymentPriceGiB.Value, + config.FlagPaymentPriceHour.Value, + )), + tequilapi_endpoints.AddRoutesForValidator, + } + if nodeOptions.ProvChecker { + handlers = append(handlers, tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions)) + } + return tequilapi.NewServer( listener, nodeOptions, di.JWTAuthenticator, - []func(engine *gin.Engine) error{ - func(e *gin.Engine) error { - if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil { - return err - } - return nil - }, - func(e *gin.Engine) error { - if config.GetBool(config.FlagPProfEnable) { - tequilapi_endpoints.AddRoutesForPProf(e) - } - return nil - }, - func(e *gin.Engine) error { - e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck) - return nil - }, - tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)), - tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), - tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), - tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), - tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions), - tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), - tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), - tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), - tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient), - tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)), - tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber), - tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)), - tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker), - tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI), - tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator), - tequilapi_endpoints.AddRoutesForConfig, - tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator), - tequilapi_endpoints.AddRoutesForFeedback(di.Reporter), - tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage), - tequilapi_endpoints.AddRoutesForDocs, - tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI), - tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver), - tequilapi_endpoints.AddRoutesForTerms, - tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator( - config.FlagPaymentPriceGiB.Value, - config.FlagPaymentPriceHour.Value, - )), - tequilapi_endpoints.AddRoutesForValidator, - }, + handlers, ) } diff --git a/cmd/di.go b/cmd/di.go index 726fbfa90c..114b9f936a 100644 --- a/cmd/di.go +++ b/cmd/di.go @@ -952,10 +952,6 @@ func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality, return err } - if nodeOptions.ProvChecker { - // di.provPinger = connection.NewProviderChecker(di.EventBus) - } - return nil } diff --git a/core/connection/manager-diag.go b/core/connection/manager-diag.go index 1a984877db..c1f7a2abef 100644 --- a/core/connection/manager-diag.go +++ b/core/connection/manager-diag.go @@ -957,7 +957,6 @@ func (m *diagConnectionManager) sendKeepAlivePing(ctx context.Context, channel p return err } - _ = start m.eventBus.Publish(quality.AppTopicConsumerPingP2P, quality.PingEvent{ SessionID: string(sessionID), Duration: time.Since(start), diff --git a/core/connection/manager_test.go b/core/connection/manager_test.go index 61f7a5e5b4..4e7ec1da1d 100644 --- a/core/connection/manager_test.go +++ b/core/connection/manager_test.go @@ -61,7 +61,6 @@ type testContext struct { statsReportInterval time.Duration mockP2P *mockP2PDialer mockTime time.Time - sync.RWMutex } diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go index 7909dfd815..690226df04 100644 --- a/services/wireguard/connection/connection.go +++ b/services/wireguard/connection/connection.go @@ -115,8 +115,6 @@ func (c *Connection) Reconnect(ctx context.Context, options connection.ConnectOp } func (c *Connection) start(ctx context.Context, start startConn, options connection.ConnectOptions) (err error) { - log.Info().Msg("+++++++++++++++++++++++++++++++++++++++++++++++++++++ *Connection) start") - var config wg.ServiceConfig if err = json.Unmarshal(options.SessionConfig, &config); err != nil { return errors.Wrap(err, "failed to unmarshal connection config") diff --git a/services/wireguard/endpoint/endpoint.go b/services/wireguard/endpoint/endpoint.go index c9b8c7bc73..ad1d1ddd28 100644 --- a/services/wireguard/endpoint/endpoint.go +++ b/services/wireguard/endpoint/endpoint.go @@ -88,8 +88,6 @@ func (ce *connectionEndpoint) StartConsumerMode(cfg wgcfg.DeviceConfig) error { } return errors.Wrap(err, "could not configure device") } - - // ce.wgClient.Diag() return nil } diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index 3ff08bc4f8..5c144a2000 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -19,11 +19,13 @@ package endpoints import ( "fmt" + "sort" "github.com/ethereum/go-ethereum/common" "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/rs/zerolog/log" + "gvisor.dev/gvisor/pkg/sync" "github.com/mysteriumnetwork/go-rest/apierror" "github.com/mysteriumnetwork/node/config" @@ -51,11 +53,13 @@ type ConnectionDiagEndpoint struct { identityRegistry identityRegistry addressProvider addressProvider identitySelector selector.Handler + + consumerAddress string } // NewConnectionDiagEndpoint creates and returns connection endpoint func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint { - return &ConnectionDiagEndpoint{ + ce := &ConnectionDiagEndpoint{ manager: manager, publisher: publisher, subscriber: subscriber, @@ -65,19 +69,153 @@ func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider sta addressProvider: addressProvider, identitySelector: identitySelector, } -} - -// Diag is used to start provider check -func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { - log.Debug().Msgf("Diag >>>") chainID := config.GetInt64(config.FlagChainID) consumerID_, err := ce.identitySelector.UseOrCreate(config.FlagIdentity.Value, config.FlagIdentityPassphrase.Value, chainID) if err != nil { - c.Error(apierror.Internal("Failed to unlock identity", err.Error())) + panic(err) + } + log.Error().Msgf("Unlocked identity: %v", consumerID_.Address) + ce.consumerAddress = consumerID_.Address + + return ce +} + +func dedupeSortedStrings(s []string) []string { + if len(s) < 2 { + return s + } + var e = 1 + for i := 1; i < len(s); i++ { + if s[i] == s[i-1] { + continue + } + s[e] = s[i] + e++ + } + + return s[:e] +} + +// DiagBatch is used to start a given providers check (batch mode) +func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { + hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) + if err != nil { + c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) return } - log.Error().Msgf("Unlocked identity: %v", consumerID_) + + provs := make([]string, 0) + c.Bind(&provs) + sort.Strings(provs) + provs = dedupeSortedStrings(provs) + + var ( + wg sync.WaitGroup + mu sync.Mutex + ) + resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(provs)) + wg.Add(len(provs)) + + for _, prov := range provs { + go func(prov string) { + result := contract.ConnectionDiagInfoDTO{ + ProviderID: prov, + } + defer func() { + mu.Lock() + resultMap[prov] = result + mu.Unlock() + + wg.Done() + }() + + cr := &contract.ConnectionCreateRequest{ + ConsumerID: ce.consumerAddress, + ProviderID: prov, + Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, + HermesID: hermes.Hex(), + ServiceType: "wireguard", + ConnectOptions: contract.ConnectOptions{}, + } + if err := cr.Validate(); err != nil { + result.Error = err + return + } + + consumerID := identity.FromAddress(cr.ConsumerID) + status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) + if err != nil { + log.Error().Err(err).Stack().Msg("Could not check registration status") + result.Error = contract.ErrCodeIDRegistrationCheck + return + } + switch status { + case registry.Unregistered, registry.RegistrationError, registry.Unknown: + log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) + result.Error = contract.ErrCodeIDNotRegistered + return + case registry.InProgress: + log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) + case registry.Registered: + log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) + default: + log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) + result.Error = contract.ErrCodeIDStatusUnknown + return + } + + if len(cr.ProviderID) > 0 { + cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) + } + f := &proposal.Filter{ + ServiceType: cr.ServiceType, + LocationCountry: cr.Filter.CountryCode, + ProviderIDs: cr.Filter.Providers, + IPType: cr.Filter.IPType, + IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, + AccessPolicy: "all", + } + proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + + if ce.manager.HasConnection(cr.ProviderID) { + result.Error = contract.ErrCodeConnectionAlreadyExists + return + } + + err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) + if err != nil { + switch err { + case connection.ErrAlreadyExists: + result.Error = contract.ErrCodeConnectionAlreadyExists + case connection.ErrConnectionCancelled: + result.Error = contract.ErrCodeConnectionCancelled + default: + log.Error().Err(err).Msgf("Failed to connect: %v", prov) + result.Error = contract.ErrCodeConnect + } + return + } + + resChannel := ce.manager.GetReadyChan(cr.ProviderID) + res := <-resChannel + log.Error().Msgf("Result > %v", res) + result.Status = res.(quality.DiagEvent).Result + + }(prov) + } + wg.Wait() + + out := make([]contract.ConnectionDiagInfoDTO, 0) + for _, prov := range provs { + out = append(out, resultMap[prov]) + } + utils.WriteAsJSON(out, c.Writer) +} + +// Diag is used to start a given provider check +func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { + log.Debug().Msgf("Diag >>>") hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) if err != nil { @@ -91,14 +229,13 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { return } cr := &contract.ConnectionCreateRequest{ - ConsumerID: consumerID_.Address, + ConsumerID: ce.consumerAddress, ProviderID: prov, Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, HermesID: hermes.Hex(), ServiceType: "wireguard", ConnectOptions: contract.ConnectOptions{}, } - if err := cr.Validate(); err != nil { c.Error(err) return @@ -111,7 +248,6 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck)) return } - switch status { case registry.Unregistered, registry.RegistrationError, registry.Unknown: log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) @@ -130,7 +266,6 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { if len(cr.ProviderID) > 0 { cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) } - f := &proposal.Filter{ ServiceType: cr.ServiceType, LocationCountry: cr.Filter.CountryCode, @@ -141,8 +276,7 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { } proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) - hasConnection := ce.manager.HasConnection(cr.ProviderID) - if hasConnection { + if ce.manager.HasConnection(cr.ProviderID) { c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) return } @@ -179,18 +313,17 @@ func AddRoutesForConnectionDiag( proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, - publisher2 eventbus.Subscriber, + subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler, options node.Options, ) func(*gin.Engine) error { - ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, publisher2, addressProvider, identitySelector) + ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, subscriber, addressProvider, identitySelector) return func(e *gin.Engine) error { connGroup := e.Group("") { - if options.ProvChecker { - connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) - } + connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) + connGroup.POST("/prov-checker-batch", ConnectionDiagEndpoint.DiagBatch) } return nil }