Skip to content

Commit

Permalink
fix change of peer identifier (public key) (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
h44z committed Jan 5, 2025
1 parent 6d86f15 commit 3020fbc
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 68 deletions.
2 changes: 1 addition & 1 deletion cmd/wg-portal/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func main() {
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
internal.AssertNoError(err)

statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, database, wireGuard, metricsServer)
statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer)
internal.AssertNoError(err)

cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem)
Expand Down
88 changes: 71 additions & 17 deletions internal/adapters/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ import (
"context"
"errors"
"fmt"
"github.com/sirupsen/logrus"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"
"os"
"path/filepath"
"strings"
"time"

"github.com/sirupsen/logrus"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/utils"

"github.com/glebarez/sqlite"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
Expand Down Expand Up @@ -204,7 +205,8 @@ func (r *SqlRepo) preCheck() error {
return nil // we probably don't have a V1 database =)
}

return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version)
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
lastVersion.Version)
}

func (r *SqlRepo) migrate() error {
Expand Down Expand Up @@ -249,7 +251,11 @@ func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifie
return &in, nil
}

func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) {
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
*domain.Interface,
[]domain.Peer,
error,
) {
in, err := r.GetInterface(ctx, id)
if err != nil {
return nil, nil, fmt.Errorf("failed to load interface: %w", err)
Expand Down Expand Up @@ -305,7 +311,11 @@ func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.I
return users, nil
}

func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error {
func (r *SqlRepo) SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error {
userInfo := domain.GetUserInfo(ctx)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreateInterface(userInfo, tx, id)
Expand Down Expand Up @@ -333,7 +343,11 @@ func (r *SqlRepo) SaveInterface(ctx context.Context, id domain.InterfaceIdentifi
return nil
}

func (r *SqlRepo) getOrCreateInterface(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.Interface, error) {
func (r *SqlRepo) getOrCreateInterface(
ui *domain.ContextUserInfo,
tx *gorm.DB,
id domain.InterfaceIdentifier,
) (*domain.Interface, error) {
var in domain.Interface

// interfaceDefaults will be applied to newly created interface records
Expand Down Expand Up @@ -449,7 +463,10 @@ func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIden
return peers, nil
}

func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) {
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
[]domain.Peer,
error,
) {
var peers []domain.Peer

searchValue := "%" + strings.ToLower(search) + "%"
Expand Down Expand Up @@ -492,7 +509,11 @@ func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, s
return peers, nil
}

func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error {
func (r *SqlRepo) SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error {
userInfo := domain.GetUserInfo(ctx)
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
peer, err := r.getOrCreatePeer(userInfo, tx, id)
Expand Down Expand Up @@ -520,7 +541,10 @@ func (r *SqlRepo) SavePeer(ctx context.Context, id domain.PeerIdentifier, update
return nil
}

func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (*domain.Peer, error) {
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
*domain.Peer,
error,
) {
var peer domain.Peer

// interfaceDefaults will be applied to newly created interface records
Expand Down Expand Up @@ -601,7 +625,10 @@ func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]d
return result, nil
}

func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) {
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
map[domain.Cidr][]domain.Cidr,
error,
) {
var peerIps []struct {
domain.Cidr
PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
Expand Down Expand Up @@ -699,7 +726,11 @@ func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User,
return users, nil
}

func (r *SqlRepo) SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error {
func (r *SqlRepo) SaveUser(
ctx context.Context,
id domain.UserIdentifier,
updateFunc func(u *domain.User) (*domain.User, error),
) error {
userInfo := domain.GetUserInfo(ctx)

err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
Expand Down Expand Up @@ -737,7 +768,10 @@ func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) erro
return nil
}

func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (*domain.User, error) {
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
*domain.User,
error,
) {
var user domain.User

// userDefaults will be applied to newly created user records
Expand Down Expand Up @@ -777,7 +811,11 @@ func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *doma

// region statistics

func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error {
func (r *SqlRepo) UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreateInterfaceStatus(tx, id)
if err != nil {
Expand All @@ -804,7 +842,10 @@ func (r *SqlRepo) UpdateInterfaceStatus(ctx context.Context, id domain.Interface
return nil
}

func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (*domain.InterfaceStatus, error) {
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
*domain.InterfaceStatus,
error,
) {
var in domain.InterfaceStatus

// defaults will be applied to newly created record
Expand All @@ -830,7 +871,11 @@ func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus)
return nil
}

func (r *SqlRepo) UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error {
func (r *SqlRepo) UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error {
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
in, err := r.getOrCreatePeerStatus(tx, id)
if err != nil {
Expand Down Expand Up @@ -883,6 +928,15 @@ func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
return nil
}

func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
if err != nil {
return err
}

return nil
}

// endregion statistics

// region audit
Expand Down
1 change: 1 addition & 0 deletions internal/app/eventbus.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ const TopicRouteUpdate = "route:update"
const TopicRouteRemove = "route:remove"
const TopicInterfaceUpdated = "interface:updated"
const TopicPeerInterfaceUpdated = "peer:interface:updated"
const TopicPeerIdentifierUpdated = "peer:identifier:updated"
44 changes: 37 additions & 7 deletions internal/app/wireguard/repos.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ type InterfaceAndPeerDatabaseRepo interface {
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error)
GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
SaveInterface(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error)) error
SaveInterface(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.Interface) (*domain.Interface, error),
) error
DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error)
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error)
SavePeer(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error)) error
SavePeer(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.Peer) (*domain.Peer, error),
) error
DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
Expand All @@ -30,18 +38,40 @@ type StatisticsDatabaseRepo interface {
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)

UpdatePeerStatus(ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error)) error
UpdateInterfaceStatus(ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error)) error
UpdatePeerStatus(
ctx context.Context,
id domain.PeerIdentifier,
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error
UpdateInterfaceStatus(
ctx context.Context,
id domain.InterfaceIdentifier,
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
) error

DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error
}

type InterfaceController interface {
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (*domain.PhysicalPeer, error)
SaveInterface(_ context.Context, id domain.InterfaceIdentifier, updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error)) error
GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) (
*domain.PhysicalPeer,
error,
)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier, updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error)) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
}

Expand Down
Loading

0 comments on commit 3020fbc

Please sign in to comment.