diff --git a/go.mod b/go.mod index e8e89f0..643f58e 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/uuid v1.3.0 github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.8.1 + github.com/stretchr/testify v1.8.3 github.com/urfave/cli v1.22.5 gopkg.in/yaml.v2 v2.4.0 ) @@ -21,6 +22,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -38,6 +40,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.0.1 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/tim-ywliu/nested-logrus-formatter v1.3.2 // indirect diff --git a/internal/context/ausf_context_init.go b/internal/context/ausf_context_init.go index 575a04f..17c6360 100644 --- a/internal/context/ausf_context_init.go +++ b/internal/context/ausf_context_init.go @@ -2,12 +2,13 @@ package context import ( "fmt" + "net/netip" "os" - "strconv" "github.com/google/uuid" "github.com/free5gc/ausf/internal/logger" + ausf_utils "github.com/free5gc/ausf/internal/utils" "github.com/free5gc/ausf/pkg/factory" "github.com/free5gc/openapi/models" ) @@ -23,36 +24,47 @@ func InitAusfContext(context *AUSFContext) { context.GroupID = configuration.GroupId context.NrfUri = configuration.NrfUri context.NrfCertPem = configuration.NrfCertPem - context.UriScheme = models.UriScheme(configuration.Sbi.Scheme) // default uri scheme - context.RegisterIPv4 = factory.AusfSbiDefaultIPv4 // default localhost - context.SBIPort = factory.AusfSbiDefaultPort // default port - if sbi != nil { - if sbi.RegisterIPv4 != "" { - context.RegisterIPv4 = sbi.RegisterIPv4 - } - if sbi.Port != 0 { - context.SBIPort = sbi.Port - } - - if sbi.Scheme == "https" { - context.UriScheme = models.UriScheme_HTTPS - } else { - context.UriScheme = models.UriScheme_HTTP - } - - context.BindingIPv4 = os.Getenv(sbi.BindingIPv4) - if context.BindingIPv4 != "" { - logger.InitLog.Info("Parsing ServerIPv4 address from ENV Variable.") - } else { - context.BindingIPv4 = sbi.BindingIPv4 - if context.BindingIPv4 == "" { - logger.InitLog.Warn("Error parsing ServerIPv4 address as string. Using the 0.0.0.0 address as default.") - context.BindingIPv4 = "0.0.0.0" - } - } + + if sbi.RegisterIP != "" { + context.RegisterIP = sbi.RegisterIP + } else if sbi.RegisterIPv4 != "" { + context.RegisterIP = sbi.RegisterIPv4 + } else { + context.RegisterIP = factory.AusfSbiDefaultIPv4 // default uri scheme + } + + if sbi.Port != 0 { + context.SBIPort = sbi.Port + } else { + context.SBIPort = factory.AusfSbiDefaultPort // default port } - context.Url = string(context.UriScheme) + "://" + context.RegisterIPv4 + ":" + strconv.Itoa(context.SBIPort) + if sbi.Scheme == "https" { + context.UriScheme = models.UriScheme_HTTPS + } else { + context.UriScheme = models.UriScheme_HTTP + } + + if bindingIP := os.Getenv(sbi.BindingIP); bindingIP != "" { + context.BindingIP = bindingIP + logger.InitLog.Info("Parsing ServerIP address from ENV Variable.") + } else if bindingIP := sbi.BindingIP; bindingIP != "" { + context.BindingIP = bindingIP + } else if bindingIPv4 := os.Getenv(sbi.BindingIPv4); bindingIPv4 != "" { + context.BindingIP = bindingIPv4 + logger.InitLog.Info("Parsing ServerIPv4 address from ENV Variable.") + } else if bindingIPv4 := sbi.BindingIPv4; bindingIPv4 != "" { + context.BindingIP = bindingIPv4 + } else { + logger.InitLog.Warn("Error parsing ServerIPv4 address as string. Using the 0.0.0.0 address as default.") + context.BindingIP = "0.0.0.0" + } + context.BindingIP = ausf_utils.BindingLookup(context.BindingIP) + + sbiRegisterIp := ausf_utils.RegisterAddr(context.RegisterIP) + sbiPort := uint16(context.SBIPort) + + context.Url = string(context.UriScheme) + "://" + netip.AddrPortFrom(sbiRegisterIp, sbiPort).String() context.PlmnList = append(context.PlmnList, configuration.PlmnSupportList...) // context.NfService @@ -74,8 +86,15 @@ func AddNfServices(serviceMap *map[models.ServiceName]models.NfService, config * nfService.ServiceName = models.ServiceName_NAUSF_AUTH var ipEndPoint models.IpEndPoint - ipEndPoint.Ipv4Address = context.RegisterIPv4 ipEndPoint.Port = int32(context.SBIPort) + + registerAddr := ausf_utils.RegisterAddr(context.RegisterIP) + if registerAddr.Is6() { + ipEndPoint.Ipv6Address = context.RegisterIP + } else if registerAddr.Is4() { + ipEndPoint.Ipv4Address = context.RegisterIP + } + ipEndPoints = append(ipEndPoints, ipEndPoint) var nfServiceVersion models.NfServiceVersion diff --git a/internal/context/context.go b/internal/context/context.go index 017ac74..d454383 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -16,8 +16,8 @@ type AUSFContext struct { NfId string GroupID string SBIPort int - RegisterIPv4 string - BindingIPv4 string + RegisterIP string + BindingIP string Url string UriScheme models.UriScheme NrfUri string diff --git a/internal/sbi/consumer/nrf_service.go b/internal/sbi/consumer/nrf_service.go index 0a5a01b..1c1ce30 100644 --- a/internal/sbi/consumer/nrf_service.go +++ b/internal/sbi/consumer/nrf_service.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "net/http" - "strconv" + "net/netip" "strings" "sync" "time" @@ -14,6 +14,7 @@ import ( ausf_context "github.com/free5gc/ausf/internal/context" "github.com/free5gc/ausf/internal/logger" + ausf_utils "github.com/free5gc/ausf/internal/utils" "github.com/free5gc/openapi" "github.com/free5gc/openapi/Nnrf_NFDiscovery" "github.com/free5gc/openapi/Nnrf_NFManagement" @@ -200,7 +201,13 @@ func (s *nnrfService) buildNfProfile(ausfContext *ausf_context.AUSFContext) (pro profile.NfInstanceId = ausfContext.NfId profile.NfType = models.NfType_AUSF profile.NfStatus = models.NfStatus_REGISTERED - profile.Ipv4Addresses = append(profile.Ipv4Addresses, ausfContext.RegisterIPv4) + + registerAddr := ausf_utils.RegisterAddr(ausfContext.RegisterIP) + if registerAddr.Is6() { + profile.Ipv6Addresses = append(profile.Ipv6Addresses, ausfContext.RegisterIP) + } else if registerAddr.Is4() { + profile.Ipv4Addresses = append(profile.Ipv4Addresses, ausfContext.RegisterIP) + } services := []models.NfService{} for _, nfService := range ausfContext.NfService { services = append(services, nfService) @@ -220,7 +227,7 @@ func (s *nnrfService) buildNfProfile(ausfContext *ausf_context.AUSFContext) (pro // }, // }, } - return + return profile, nil } func (s *nnrfService) GetUdmUrl(nrfUri string) string { @@ -235,13 +242,18 @@ func (s *nnrfService) GetUdmUrl(nrfUri string) string { nfDiscoverParam, ) if err != nil { - logger.ConsumerLog.Errorln("[Search UDM UEAU] ", err.Error(), "use defalt udmUrl", udmUrl) + logger.ConsumerLog.Errorln("[Search UDM UEAU] ", err.Error(), "use default udmUrl", udmUrl) } else if len(res.NfInstances) > 0 { udmInstance := res.NfInstances[0] - if len(udmInstance.Ipv4Addresses) > 0 && udmInstance.NfServices != nil { - ueauService := (*udmInstance.NfServices)[0] - ueauEndPoint := (*ueauService.IpEndPoints)[0] - udmUrl = string(ueauService.Scheme) + "://" + ueauEndPoint.Ipv4Address + ":" + strconv.Itoa(int(ueauEndPoint.Port)) + ueauService := (*udmInstance.NfServices)[0] + ueauEndPoint := (*ueauService.IpEndPoints)[0] + port := uint16(ueauEndPoint.Port) + if len(udmInstance.Ipv6Addresses) > 0 && udmInstance.NfServices != nil { + registerIp := ausf_utils.RegisterAddr(ueauEndPoint.Ipv6Address) + udmUrl = string(ueauService.Scheme) + "://" + netip.AddrPortFrom(registerIp, port).String() + } else if len(udmInstance.Ipv4Addresses) > 0 && udmInstance.NfServices != nil { + registerIp := ausf_utils.RegisterAddr(ueauEndPoint.Ipv4Address) + udmUrl = string(ueauService.Scheme) + "://" + netip.AddrPortFrom(registerIp, port).String() } } else { logger.ConsumerLog.Errorln("[Search UDM UEAU] len(NfInstances) = 0") diff --git a/internal/utils/net.go b/internal/utils/net.go new file mode 100644 index 0000000..e0d074e --- /dev/null +++ b/internal/utils/net.go @@ -0,0 +1,28 @@ +package net + +import ( + "net" + "net/netip" + + "github.com/free5gc/ausf/internal/logger" +) + +func BindingLookup(bindIP string) string { + ips, err := net.LookupIP(bindIP) + if err != nil { + logger.InitLog.Errorf("Resolve BindingIP hostname %s failed: %+v", bindIP, err) + } + return ips[0].String() +} + +func RegisterAddr(registerIP string) netip.Addr { + ips, err := net.LookupIP(registerIP) + if err != nil { + logger.InitLog.Errorf("Resolve RegisterIP hostname %s failed: %+v", registerIP, err) + } + ip, err := netip.ParseAddr(ips[0].String()) + if err != nil { + logger.InitLog.Errorf("Parse RegisterIP hostname %s failed: %+v", registerIP, err) + } + return ip +} diff --git a/internal/utils/net_test.go b/internal/utils/net_test.go new file mode 100644 index 0000000..7706b6a --- /dev/null +++ b/internal/utils/net_test.go @@ -0,0 +1,66 @@ +package net + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBindingLookupWithLocalhost(t *testing.T) { + expectedBindIP := "::1" + + bindIP := BindingLookup("localhost") + assert.Equal(t, bindIP, expectedBindIP) +} + +func TestBindingLookupWithIpV4(t *testing.T) { + expectedBindIP := "127.0.0.1" + + bindIP := BindingLookup("127.0.0.1") + assert.Equal(t, bindIP, expectedBindIP) +} + +func TestBindingLookupWithIpV6(t *testing.T) { + expectedBindIP := "2001:db8::1:0:0:1" + + bindIP := BindingLookup("2001:db8::1:0:0:1") + assert.Equal(t, bindIP, expectedBindIP) +} + +func TestRegisterAddrWithLocalhost(t *testing.T) { + expectedAddr, err := netip.ParseAddr("::1") + if err != nil { + t.Errorf("invalid expected IP: %+v", expectedAddr) + } + + addr := RegisterAddr("localhost") + if addr != expectedAddr { + t.Errorf("invalid IP: %+v", addr) + } + assert.Equal(t, addr, expectedAddr) +} + +func TestRegisterAddrWithIpV4(t *testing.T) { + expectedAddr, err := netip.ParseAddr("127.0.0.1") + if err != nil { + t.Errorf("invalid expected IP: %+v", expectedAddr) + } + + addr := RegisterAddr("127.0.0.1") + if addr != expectedAddr { + t.Errorf("invalid IP: %+v", addr) + } +} + +func TestRegisterAddrWithIpV6(t *testing.T) { + expectedAddr, err := netip.ParseAddr("2001:db8::1:0:0:1") + if err != nil { + t.Errorf("invalid expected IP: %+v", expectedAddr) + } + + addr := RegisterAddr("2001:db8::1:0:0:1") + if addr != expectedAddr { + t.Errorf("invalid IP: %+v", addr) + } +} diff --git a/pkg/factory/config.go b/pkg/factory/config.go index e61b852..6ca6584 100644 --- a/pkg/factory/config.go +++ b/pkg/factory/config.go @@ -7,6 +7,7 @@ package factory import ( "errors" "fmt" + "net/netip" "os" "strconv" "sync" @@ -14,6 +15,7 @@ import ( "github.com/asaskevich/govalidator" "github.com/free5gc/ausf/internal/logger" + ausf_utils "github.com/free5gc/ausf/internal/utils" "github.com/free5gc/openapi/models" ) @@ -50,7 +52,7 @@ func (c *Config) Validate() (bool, error) { } type Info struct { - Version string `yaml:"version,omitempty" valid:"required,in(1.0.3)"` + Version string `yaml:"version,omitempty" valid:"required,in(1.0.4)"` Description string `yaml:"description,omitempty" valid:"type(string)"` } @@ -106,16 +108,29 @@ func (c *Configuration) validate() (bool, error) { } type Sbi struct { - Scheme string `yaml:"scheme" valid:"scheme"` - RegisterIPv4 string `yaml:"registerIPv4,omitempty" valid:"host,required"` // IP that is registered at NRF. - BindingIPv4 string `yaml:"bindingIPv4,omitempty" valid:"host,required"` // IP used to run the server in the node. - Port int `yaml:"port,omitempty" valid:"port,required"` + Scheme string `yaml:"scheme" valid:"in(http|https)"` + RegisterIPv4 string `yaml:"registerIPv4,omitempty" valid:"host,optional"` // IP that is registered at NRF. + RegisterIP string `yaml:"registerIP,omitempty" valid:"host,optional"` // IP that is registered at NRF. + BindingIPv4 string `yaml:"bindingIPv4,omitempty" valid:"host,optional"` // IP used to run the server in the node. + BindingIP string `yaml:"bindingIP,omitempty" valid:"host,optional"` // IP used to run the server in the node. + Port int `yaml:"port,omitempty" valid:"port,required,with_register,with_binding"` Tls *Tls `yaml:"tls,omitempty" valid:"optional"` } func (s *Sbi) validate() (bool, error) { - govalidator.TagMap["scheme"] = govalidator.Validator(func(str string) bool { - return str == "https" || str == "http" + govalidator.CustomTypeTagMap.Set("with_register", func(i interface{}, context interface{}) bool { + switch v := context.(type) { + case Sbi: + return (v.RegisterIPv4 != "" && v.RegisterIP == "") || (v.RegisterIP != "" && v.RegisterIPv4 == "") + } + return false + }) + govalidator.CustomTypeTagMap.Set("with_binding", func(i interface{}, context interface{}) bool { + switch v := context.(type) { + case Sbi: + return (v.BindingIPv4 != "" && v.BindingIP == "") || (v.BindingIP != "" && v.BindingIPv4 == "") + } + return false }) if tls := s.Tls; tls != nil { @@ -240,7 +255,14 @@ func (c *Config) GetLogReportCaller() bool { func (c *Config) GetSbiBindingAddr() string { c.RLock() defer c.RUnlock() - return c.GetSbiBindingIP() + ":" + strconv.Itoa(c.GetSbiPort()) + + bindIP, err := netip.ParseAddr(c.GetSbiBindingIP()) + if err != nil { + logger.CfgLog.Warnf("Logger should not be nil") + return "" + } + sbiPort := uint16(c.GetSbiPort()) + return netip.AddrPortFrom(bindIP, sbiPort).String() } func (c *Config) GetSbiBindingIP() string { @@ -250,14 +272,20 @@ func (c *Config) GetSbiBindingIP() string { if c.Configuration == nil || c.Configuration.Sbi == nil { return bindIP } - if c.Configuration.Sbi.BindingIPv4 != "" { + if c.Configuration.Sbi.BindingIP != "" { + if bindIP = os.Getenv(c.Configuration.Sbi.BindingIP); bindIP != "" { + logger.CfgLog.Infof("Parsing ServerIP [%s] from ENV Variable", bindIP) + } else { + bindIP = c.Configuration.Sbi.BindingIP + } + } else if c.Configuration.Sbi.BindingIPv4 != "" { if bindIP = os.Getenv(c.Configuration.Sbi.BindingIPv4); bindIP != "" { logger.CfgLog.Infof("Parsing ServerIPv4 [%s] from ENV Variable", bindIP) } else { bindIP = c.Configuration.Sbi.BindingIPv4 } } - return bindIP + return ausf_utils.BindingLookup(bindIP) } func (c *Config) GetSbiPort() int {