Skip to content

Commit b8d95f3

Browse files
committed
Refactor parse args
1 parent 83f644c commit b8d95f3

File tree

5 files changed

+73
-59
lines changed

5 files changed

+73
-59
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/cloverstd/tcping
22

3-
go 1.17
3+
go 1.18
44

55
require (
66
github.com/smartystreets/goconvey v1.7.2

main.go

+44-54
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ import (
44
"fmt"
55
"os"
66
"os/signal"
7-
"syscall"
8-
"time"
9-
107
"strconv"
8+
"syscall"
119

1210
"github.com/cloverstd/tcping/ping"
1311
"github.com/spf13/cobra"
@@ -43,78 +41,70 @@ var rootCmd = cobra.Command{
4341
3. ping over http
4442
> tcping -H google.com
4543
4. ping with URI schema
46-
> tcping http://hui.lu
44+
> tcping https://hui.lu
4745
`,
4846
Run: func(cmd *cobra.Command, args []string) {
4947
if showVersion {
5048
fmt.Printf("version: %s\n", version)
5149
fmt.Printf("git: %s\n", gitCommit)
5250
return
5351
}
54-
if len(args) != 2 && len(args) != 1 {
52+
if len(args) == 0 {
5553
cmd.Usage()
5654
return
5755
}
56+
if len(args) > 2 {
57+
cmd.Println("invalid command arguments")
58+
return
59+
}
5860
host := args[0]
5961

62+
url, err := ping.ParseAddress(host)
63+
if err != nil {
64+
fmt.Printf("%s is an invalid target.\n", host)
65+
return
66+
}
67+
defaultPort := "80"
68+
if len(args) > 1 {
69+
defaultPort = args[1]
70+
}
71+
port, err := strconv.Atoi(defaultPort)
72+
if err != nil {
73+
cmd.Printf("%s is invalid port.\n", defaultPort)
74+
return
75+
}
76+
url.Host = fmt.Sprintf("%s:%d", url.Hostname(), port)
77+
6078
var (
61-
err error
62-
port int
6379
schema string
6480
)
65-
if len(args) == 2 {
66-
port, err = strconv.Atoi(args[1])
67-
if err != nil {
68-
fmt.Println("port should be integer")
69-
cmd.Usage()
70-
return
71-
}
72-
schema = ping.TCP.String()
73-
} else {
74-
var matched bool
75-
schema, host, port, matched = ping.CheckURI(host)
76-
if !matched {
77-
fmt.Println("not a valid uri")
78-
cmd.Usage()
79-
return
80-
}
81-
}
82-
var timeoutDuration time.Duration
83-
if res, err := strconv.Atoi(timeout); err == nil {
84-
timeoutDuration = time.Duration(res) * time.Millisecond
85-
} else {
86-
timeoutDuration, err = time.ParseDuration(timeout)
87-
if err != nil {
88-
fmt.Println("parse timeout failed", err)
89-
cmd.Usage()
90-
return
91-
}
81+
82+
timeoutDuration, err := ping.ParseDuration(timeout)
83+
if err != nil {
84+
cmd.Println("parse timeout failed", err)
85+
cmd.Usage()
86+
return
9287
}
9388

94-
var intervalDuration time.Duration
95-
if res, err := strconv.Atoi(interval); err == nil {
96-
intervalDuration = time.Duration(res) * time.Millisecond
97-
} else {
98-
intervalDuration, err = time.ParseDuration(interval)
99-
if err != nil {
100-
fmt.Println("parse interval failed", err)
101-
cmd.Usage()
102-
return
103-
}
89+
intervalDuration, err := ping.ParseDuration(interval)
90+
if err != nil {
91+
cmd.Println("parse interval failed", err)
92+
cmd.Usage()
93+
return
10494
}
105-
var protocol ping.Protocol
95+
10696
if httpMode {
107-
protocol = ping.HTTP
108-
} else {
109-
protocol, err = ping.NewProtocol(schema)
110-
if err != nil {
111-
fmt.Println(err)
112-
cmd.Usage()
113-
return
114-
}
97+
url.Scheme = ping.HTTP.String()
98+
}
99+
protocol, err := ping.NewProtocol(url.Scheme)
100+
if err != nil {
101+
cmd.Println("invalid protocol", err)
102+
cmd.Usage()
103+
return
115104
}
105+
116106
if len(dnsServer) != 0 {
117-
ping.UseCustomeDNS(dnsServer)
107+
ping.UseCustomDNS(dnsServer)
118108
}
119109

120110
parseHost, _ := ping.FormatIP(host)

ping/address.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package ping
2+
3+
import (
4+
pkgurl "net/url"
5+
"strings"
6+
)
7+
8+
// ParseAddress will try to parse addr as url.URL.
9+
func ParseAddress(addr string) (*pkgurl.URL, error) {
10+
if strings.Contains(addr, "://") {
11+
// it maybe with scheme, try url.Parse
12+
return pkgurl.Parse(addr)
13+
}
14+
return pkgurl.Parse("tcp://" + addr)
15+
}

ping/ping.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func (protocol Protocol) String() string {
2222
case HTTPS:
2323
return "https"
2424
}
25-
return "unkown"
25+
return "unknown"
2626
}
2727

2828
const (
@@ -34,7 +34,7 @@ const (
3434
HTTPS
3535
)
3636

37-
// NewProtocol convert protocol stirng to Protocol
37+
// NewProtocol convert protocol string to Protocol
3838
func NewProtocol(protocol string) (Protocol, error) {
3939
switch strings.ToLower(protocol) {
4040
case TCP.String():

ping/utils.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net"
7+
"strconv"
78
"strings"
89
"time"
910
)
@@ -15,8 +16,8 @@ func timeIt(f func() interface{}) (int64, interface{}) {
1516
return endAt.UnixNano() - startAt.UnixNano(), res
1617
}
1718

18-
// UseCustomeDNS will set the dns to default DNS resolver for global
19-
func UseCustomeDNS(dns []string) {
19+
// UseCustomDNS will set the dns to default DNS resolver for global
20+
func UseCustomDNS(dns []string) {
2021
resolver := net.Resolver{
2122
PreferGo: true,
2223
Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) {
@@ -53,3 +54,11 @@ func FormatIP(IP string) (string, error) {
5354
}
5455
return "", fmt.Errorf("Error IP format")
5556
}
57+
58+
// ParseDuration parse the t as time.Duration, it will parse t as mills when missing unit.
59+
func ParseDuration(t string) (time.Duration, error) {
60+
if timeout, err := strconv.ParseInt(t, 10, 64); err == nil {
61+
return time.Duration(timeout) * time.Millisecond, nil
62+
}
63+
return time.ParseDuration(t)
64+
}

0 commit comments

Comments
 (0)