diff --git a/client/client.go b/client/client.go index 311b800a..bcb11642 100644 --- a/client/client.go +++ b/client/client.go @@ -15,6 +15,7 @@ package client import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -28,9 +29,11 @@ import ( "github.com/dapr/go-sdk/actor" "github.com/dapr/go-sdk/actor/config" + "github.com/dapr/go-sdk/client/internal" "github.com/dapr/go-sdk/version" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -43,6 +46,7 @@ import ( const ( daprPortDefault = "50001" daprPortEnvVarName = "DAPR_GRPC_PORT" /* #nosec */ + daprGRPCEndpointEnvVarName = "DAPR_GRPC_ENDPOINT" traceparentKey = "traceparent" apiTokenKey = "dapr-api-token" /* #nosec */ apiTokenEnvVarName = "DAPR_API_TOKEN" /* #nosec */ @@ -219,18 +223,28 @@ type Client interface { // NewClientWithConnection(conn *grpc.ClientConn) Client // NewClientWithSocket(socket string) (client Client, err error) func NewClient() (client Client, err error) { - port := os.Getenv(daprPortEnvVarName) - if port == "" { - port = daprPortDefault - } - if defaultClient != nil { - return defaultClient, nil - } lock.Lock() defer lock.Unlock() + if defaultClient != nil { return defaultClient, nil } + + addr, ok := os.LookupEnv(daprGRPCEndpointEnvVarName) + if ok { + client, err = NewClientWithAddress(addr) + if err != nil { + return nil, fmt.Errorf("error creating %q client: %w", daprGRPCEndpointEnvVarName, err) + } + defaultClient = client + return defaultClient, nil + } + + port, ok := os.LookupEnv(daprPortEnvVarName) + if !ok { + port = daprPortDefault + } + c, err := NewClientWithPort(port) if err != nil { return nil, fmt.Errorf("error creating default client: %w", err) @@ -266,13 +280,28 @@ func NewClientWithAddressContext(ctx context.Context, address string) (client Cl if err != nil { return nil, err } + + parsedAddress, err := internal.ParseGRPCEndpoint(address) + if err != nil { + return nil, fmt.Errorf("error parsing address '%s': %w", address, err) + } + + opts := []grpc.DialOption{ + grpc.WithUserAgent(userAgent()), + grpc.WithBlock(), + } + + if parsedAddress.TLS { + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(new(tls.Config)))) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) conn, err := grpc.DialContext( ctx, - address, - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUserAgent(userAgent()), - grpc.WithBlock(), + parsedAddress.Target, + opts..., ) cancel() if err != nil { diff --git a/client/internal/parse.go b/client/internal/parse.go new file mode 100644 index 00000000..a6f2bfc3 --- /dev/null +++ b/client/internal/parse.go @@ -0,0 +1,177 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "errors" + "fmt" + "net" + "net/url" + "strings" +) + +// Parsed represents a parsed gRPC endpoint. +type Parsed struct { + Target string + TLS bool +} + +//nolint:revive +func ParseGRPCEndpoint(endpoint string) (Parsed, error) { + target := endpoint + if len(target) == 0 { + return Parsed{}, errors.New("target is required") + } + + var dnsAuthority string + var hostname string + var tls bool + + urlSplit := strings.Split(target, ":") + if len(urlSplit) == 3 && !strings.Contains(target, "://") { + target = strings.Replace(target, ":", "://", 1) + } else if len(urlSplit) >= 2 && !strings.Contains(target, "://") && schemeKnown(urlSplit[0]) { + target = strings.Replace(target, ":", "://", 1) + } else { + urlSplit = strings.Split(target, "://") + if len(urlSplit) == 1 { + target = "dns://" + target + } else { + scheme := urlSplit[0] + if !schemeKnown(scheme) { + return Parsed{}, fmt.Errorf(("unknown scheme: %q"), scheme) + } + + if scheme == "dns" { + urlSplit = strings.Split(target, "/") + if len(urlSplit) < 4 { + return Parsed{}, fmt.Errorf("invalid dns scheme: %q", target) + } + dnsAuthority = urlSplit[2] + target = "dns://" + urlSplit[3] + } + } + } + + ptarget, err := url.Parse(target) + if err != nil { + return Parsed{}, err + } + + var errs []string + for k := range ptarget.Query() { + if k != "tls" { + errs = append(errs, fmt.Sprintf("unrecognized query parameter: %q", k)) + } + } + if len(errs) > 0 { + return Parsed{}, fmt.Errorf("failed to parse target %q: %s", target, strings.Join(errs, "; ")) + } + + if ptarget.Query().Has("tls") { + if ptarget.Scheme == "http" || ptarget.Scheme == "https" { + return Parsed{}, errors.New("cannot use tls query parameter with http(s) scheme") + } + + qtls := ptarget.Query().Get("tls") + if qtls != "true" && qtls != "false" { + return Parsed{}, fmt.Errorf("invalid value for tls query parameter: %q", qtls) + } + + tls = qtls == "true" + } + + scheme := ptarget.Scheme + if scheme == "https" { + tls = true + } + if scheme == "http" || scheme == "https" { + scheme = "dns" + } + + hostname = ptarget.Host + + host, port, err := net.SplitHostPort(hostname) + aerr, ok := err.(*net.AddrError) + if ok && aerr.Err == "missing port in address" { + port = "443" + } else if err != nil { + return Parsed{}, err + } else { + hostname = host + } + + if len(hostname) == 0 { + if scheme == "dns" { + hostname = "localhost" + } else { + hostname = ptarget.Path + } + } + + switch scheme { + case "unix": + separator := ":" + if strings.HasPrefix(endpoint, "unix://") { + separator = "://" + } + target = scheme + separator + hostname + + case "vsock": + target = scheme + ":" + hostname + ":" + port + + case "unix-abstract": + target = scheme + ":" + hostname + + case "dns": + if len(ptarget.Path) > 0 { + return Parsed{}, fmt.Errorf("path is not allowed: %q", ptarget.Path) + } + + if strings.Count(hostname, ":") == 7 && !strings.HasPrefix(hostname, "[") && !strings.HasSuffix(hostname, "]") { + hostname = "[" + hostname + "]" + } + if len(dnsAuthority) > 0 { + dnsAuthority = "//" + dnsAuthority + "/" + } + target = scheme + ":" + dnsAuthority + hostname + ":" + port + + default: + return Parsed{}, fmt.Errorf("unsupported scheme: %q", scheme) + } + + return Parsed{ + Target: target, + TLS: tls, + }, nil +} + +func schemeKnown(scheme string) bool { + for _, s := range []string{ + "dns", + "unix", + "unix-abstract", + "vsock", + "http", + "https", + "grpc", + "grpcs", + } { + if scheme == s { + return true + } + } + + return false +} diff --git a/client/internal/parse_test.go b/client/internal/parse_test.go new file mode 100644 index 00000000..46bbbb35 --- /dev/null +++ b/client/internal/parse_test.go @@ -0,0 +1,293 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package internal + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + tests := map[string]struct { + expTarget string + expTLS bool + expError bool + }{ + "": { + expTarget: "", + expTLS: false, + expError: true, + }, + ":5000": { + expTarget: "dns:localhost:5000", + expTLS: false, + expError: false, + }, + ":5000?tls=false": { + expTarget: "dns:localhost:5000", + expTLS: false, + expError: false, + }, + ":5000?tls=true": { + expTarget: "dns:localhost:5000", + expTLS: true, + expError: false, + }, + "myhost": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "myhost?tls=false": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "myhost?tls=true": { + expTarget: "dns:myhost:443", + expTLS: true, + expError: false, + }, + "myhost:443": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "myhost:443?tls=false": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "myhost:443?tls=true": { + expTarget: "dns:myhost:443", + expTLS: true, + expError: false, + }, + "http://myhost": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "http://myhost?tls=false": { + expTarget: "", + expTLS: false, + expError: true, + }, + "http://myhost?tls=true": { + expTarget: "", + expTLS: false, + expError: true, + }, + "http://myhost:443": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "http://myhost:443?tls=false": { + expTarget: "", + expTLS: false, + expError: true, + }, + "http://myhost:443?tls=true": { + expTarget: "", + expTLS: false, + expError: true, + }, + "http://myhost:5000": { + expTarget: "dns:myhost:5000", + expTLS: false, + expError: false, + }, + "http://myhost:5000?tls=false": { + expTarget: "", + expTLS: false, + expError: true, + }, + "http://myhost:5000?tls=true": { + expTarget: "", + expTLS: false, + expError: true, + }, + "https://myhost:443": { + expTarget: "dns:myhost:443", + expTLS: true, + expError: false, + }, + "https://myhost:443/tls=false": { + expTarget: "", + expTLS: false, + expError: true, + }, + "https://myhost:443?tls=true": { + expTarget: "", + expTLS: false, + expError: true, + }, + "dns:myhost": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "dns:myhost?tls=false": { + expTarget: "dns:myhost:443", + expTLS: false, + expError: false, + }, + "dns:myhost?tls=true": { + expTarget: "dns:myhost:443", + expTLS: true, + expError: false, + }, + "dns://myauthority:53/myhost": { + expTarget: "dns://myauthority:53/myhost:443", + expTLS: false, + expError: false, + }, + "dns://myauthority:53/myhost?tls=false": { + expTarget: "dns://myauthority:53/myhost:443", + expTLS: false, + expError: false, + }, + "dns://myauthority:53/myhost?tls=true": { + expTarget: "dns://myauthority:53/myhost:443", + expTLS: true, + expError: false, + }, + "dns://myhost": { + expTarget: "", + expTLS: false, + expError: true, + }, + "unix:my.sock": { + expTarget: "unix:my.sock", + expTLS: false, + expError: false, + }, + "unix:my.sock?tls=true": { + expTarget: "unix:my.sock", + expTLS: true, + expError: false, + }, + "unix://my.sock": { + expTarget: "unix://my.sock", + expTLS: false, + expError: false, + }, + "unix:///my.sock": { + expTarget: "unix:///my.sock", + expTLS: false, + expError: false, + }, + "unix://my.sock?tls=true": { + expTarget: "unix://my.sock", + expTLS: true, + expError: false, + }, + "unix-abstract:my.sock": { + expTarget: "unix-abstract:my.sock", + expTLS: false, + expError: false, + }, + "unix-abstract:my.sock?tls=false": { + expTarget: "unix-abstract:my.sock", + expTLS: false, + expError: false, + }, + "unix-abstract:my.sock?tls=true": { + expTarget: "unix-abstract:my.sock", + expTLS: true, + expError: false, + }, + "vsock:mycid:5000": { + expTarget: "vsock:mycid:5000", + expTLS: false, + expError: false, + }, + "vsock:mycid:5000?tls=false": { + expTarget: "vsock:mycid:5000", + expTLS: false, + expError: false, + }, + "vsock:mycid:5000?tls=true": { + expTarget: "vsock:mycid:5000", + expTLS: true, + expError: false, + }, + "dns:1.2.3.4:443": { + expTarget: "dns:1.2.3.4:443", + expTLS: false, + expError: false, + }, + "dns:[2001:db8:1f70::999:de8:7648:6e8]:443": { + expTarget: "dns:[2001:db8:1f70::999:de8:7648:6e8]:443", + expTLS: false, + expError: false, + }, + "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000": { + expTarget: "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", + expTLS: false, + expError: false, + }, + "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000?abc=[]": { + expTarget: "", + expTLS: false, + expError: true, + }, + "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]": { + expTarget: "dns://myauthority:53/[2001:db8:1f70::999:de8:7648:6e8]:443", + expTLS: false, + expError: false, + }, + "https://[2001:db8:1f70::999:de8:7648:6e8]": { + expTarget: "dns:[2001:db8:1f70::999:de8:7648:6e8]:443", + expTLS: true, + expError: false, + }, + "https://[2001:db8:1f70::999:de8:7648:6e8]:5000": { + expTarget: "dns:[2001:db8:1f70::999:de8:7648:6e8]:5000", + expTLS: true, + expError: false, + }, + "host:5000/v1/dapr": { + expTarget: "", + expTLS: false, + expError: true, + }, + "host:5000/?a=1": { + expTarget: "", + expTLS: false, + expError: true, + }, + "inv-scheme://myhost": { + expTarget: "", + expTLS: false, + expError: true, + }, + "inv-scheme:myhost:5000": { + expTarget: "", + expTLS: false, + expError: true, + }, + } + + for url, tc := range tests { + t.Run(url, func(t *testing.T) { + parsed, err := ParseGRPCEndpoint(url) + assert.Equalf(t, tc.expError, err != nil, "%v", err) + assert.Equal(t, tc.expTarget, parsed.Target) + assert.Equal(t, tc.expTLS, parsed.TLS) + }) + } +}