diff --git a/pkg/client/client.go b/pkg/client/client.go index d46865b..a0dd37d 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -42,6 +42,7 @@ func (bc stdBleConnector) Scan(ctx context.Context, b bool, h ble.AdvHandler, f // BLEClientInt is a interface used to abstract BLEClient type BLEClientInt interface { + RawScanWithDuration(time.Duration, func(ble.Advertisement)) error RawScan(func(ble.Advertisement)) error ReadValue(string) ([]byte, error) RawConnect(ble.AdvFilter) error @@ -268,7 +269,17 @@ func (client *BLEClient) serverFilter(a ble.Advertisement) bool { // RawScan exposes underlying BLE scanner func (client *BLEClient) RawScan(handle func(ble.Advertisement)) error { - return client.bleConnector.Scan(client.ctx, true, handle, nil) + return client.scanWithCtx(client.ctx, handle) +} + +// RawScanWithDuration exposes underlying BLE scanner (with timeout) +func (client *BLEClient) RawScanWithDuration(duration time.Duration, handle func(ble.Advertisement)) error { + ctx, _ := context.WithTimeout(context.Background(), duration) + return client.scanWithCtx(ctx, handle) +} + +func (client *BLEClient) scanWithCtx(ctx context.Context, handle func(ble.Advertisement)) error { + return client.bleConnector.Scan(ctx, true, handle, nil) } func (client *BLEClient) scan() { diff --git a/pkg/forwarder/forwarder.go b/pkg/forwarder/forwarder.go index e7e0e2c..2f1e10f 100644 --- a/pkg/forwarder/forwarder.go +++ b/pkg/forwarder/forwarder.go @@ -17,6 +17,7 @@ import ( ) const ( + scanDuration = time.Second * 2 maxConnectAttempts = 5 errNotConnected = "Forwarder is not connected" errInvalidForwardReq = "Invalid forwarding request" @@ -115,18 +116,35 @@ func (forwarder *BLEForwarder) Run() error { return forwarder.forwardingServer.Run() } +func (forwarder *BLEForwarder) collectAdvirtisements() ([]ble.Advertisement, error) { + ret := []ble.Advertisement{} + mutex := sync.Mutex{} + err := forwarder.forwardingClient.RawScanWithDuration(scanDuration, func(a ble.Advertisement) { + mutex.Lock() + ret = append(ret, a) + mutex.Unlock() + }) + if err != nil && err.Error() == "context deadline exceeded" { + err = nil + } + return ret, err +} + func (forwarder *BLEForwarder) scanLoop() { - mutex := &sync.Mutex{} for { time.Sleep(client.ScanInterval) - forwarder.forwardingClient.RawScan(func(a ble.Advertisement) { - mutex.Lock() + advs, err := forwarder.collectAdvirtisements() + if err != nil { + e := errors.Wrap(err, "collectAdvirtisements error") + forwarder.listener.OnConnectionError(e) + } + for _, a := range advs { err := forwarder.onScanned(a) if err != nil { - forwarder.listener.OnError(err) + e := errors.Wrap(err, "onScanned error") + forwarder.listener.OnError(e) } - mutex.Unlock() - }) + } } } @@ -146,12 +164,12 @@ func (forwarder *BLEForwarder) onScanned(a ble.Advertisement) error { forwarder.rssiMap.Set(forwarder.addr, addr, rssi) isF := client.IsForwarder(a) var err error - if addr != forwarder.serverAddr && isF { + if !util.AddrEqualAddr(addr, forwarder.serverAddr) && isF { err = forwarder.updateNetworkState(addr) e := forwarder.reconnect() err = wrapError(err, e) } - if addr == forwarder.serverAddr || isF { + if util.AddrEqualAddr(addr, forwarder.serverAddr) || isF { e := forwarder.refreshShortestPath() err = wrapError(err, e) } @@ -194,13 +212,13 @@ func (forwarder *BLEForwarder) reconnect() error { func (forwarder *BLEForwarder) refreshShortestPath() error { path, err := util.ShortestPath(forwarder.rssiMap.GetAll(), forwarder.addr, forwarder.serverAddr) if err != nil { - return err + return errors.Wrap(err, "Could not calc shortest path.") } if len(path) < 2 { return fmt.Errorf("Invalid path to server: %s", path) } nextHop := path[1] - if forwarder.toConnectAddr != nextHop { + if !util.AddrEqualAddr(forwarder.toConnectAddr, nextHop) { forwarder.toConnectAddr = nextHop err = forwarder.keepTryConnect(nextHop) } @@ -210,24 +228,35 @@ func (forwarder *BLEForwarder) refreshShortestPath() error { func (forwarder *BLEForwarder) keepTryConnect(addr string) error { err := errors.New("") attempts := 0 + rssi := 0 for err != nil && attempts < maxConnectAttempts { - err = forwarder.connect(addr) + rssi, err = forwarder.connect(addr) + if err != nil { + e := errors.Wrap(err, "keepTryConnect single connection error") + forwarder.listener.OnConnectionError(e) + } attempts++ } - return err + forwarder.listener.OnClientConnected(addr, attempts, rssi) + return nil } -func (forwarder *BLEForwarder) connect(addr string) error { +func (forwarder *BLEForwarder) connect(addr string) (int, error) { forwarder.connectedAddr = "" + rssi := 0 err := forwarder.forwardingClient.RawConnect(func(a ble.Advertisement) bool { - return util.AddrEqualAddr(a.Address().String(), addr) + b := util.AddrEqualAddr(a.Address().String(), addr) + if b { + rssi = a.RSSI() + } + return b }) if err != nil { - return err + return 0, err } forwarder.connectedAddr = addr forwarder.connectionGraph.Set(forwarder.addr, addr) - return nil + return rssi, nil } func (forwarder *BLEForwarder) isConnected() bool { @@ -235,7 +264,7 @@ func (forwarder *BLEForwarder) isConnected() bool { } func (forwarder *BLEForwarder) isConnectedToServer() bool { - return forwarder.connectedAddr == forwarder.serverAddr + return util.AddrEqualAddr(forwarder.connectedAddr, forwarder.serverAddr) } func noop() {} diff --git a/pkg/forwarder/forwarder_test.go b/pkg/forwarder/forwarder_test.go index 662fdb8..e9a1edb 100644 --- a/pkg/forwarder/forwarder_test.go +++ b/pkg/forwarder/forwarder_test.go @@ -49,6 +49,10 @@ func (c dummyClient) RawScan(f func(ble.Advertisement)) error { return nil } +func (c dummyClient) RawScanWithDuration(_ time.Duration, f func(ble.Advertisement)) error { + return c.RawScan(f) +} + func (c dummyClient) ReadValue(uuid string) ([]byte, error) { return c.mockedReadValue.Bytes(), nil } diff --git a/pkg/models/connection_graph.go b/pkg/models/connection_graph.go index 4e3f15c..97a6ba2 100644 --- a/pkg/models/connection_graph.go +++ b/pkg/models/connection_graph.go @@ -3,6 +3,7 @@ package models import ( "bytes" "encoding/gob" + "strings" "sync" ) @@ -29,6 +30,8 @@ func (cg *ConnectionGraph) Data() ([]byte, error) { // Set will update the map func (cg *ConnectionGraph) Set(src, new string) { + src = strings.ToUpper(src) + new = strings.ToUpper(new) cg.mutex.Lock() cg.data[src] = new cg.mutex.Unlock() @@ -41,6 +44,7 @@ func (cg *ConnectionGraph) GetAll() map[string]string { // Get will get from map func (cg *ConnectionGraph) Get(src string) (string, bool) { + src = strings.ToUpper(src) cg.mutex.RLock() ret, ok := cg.data[src] cg.mutex.RUnlock() diff --git a/pkg/models/rssi_map.go b/pkg/models/rssi_map.go index 104a1d4..b6f842c 100644 --- a/pkg/models/rssi_map.go +++ b/pkg/models/rssi_map.go @@ -3,6 +3,7 @@ package models import ( "bytes" "encoding/gob" + "strings" "sync" ) @@ -29,6 +30,8 @@ func (rm *RssiMap) Data() ([]byte, error) { // Set will update the map func (rm *RssiMap) Set(src, dst string, new int) { + src = strings.ToUpper(src) + dst = strings.ToUpper(dst) rm.mutex.Lock() if _, ok := rm.data[src]; !ok { rm.data[src] = map[string]int{} @@ -44,6 +47,8 @@ func (rm *RssiMap) GetAll() map[string]map[string]int { // Get will get from map func (rm *RssiMap) Get(src, dst string) (int, bool) { + src = strings.ToUpper(src) + dst = strings.ToUpper(dst) rm.mutex.RLock() if tmp, ok := rm.data[src]; ok { ret, oke := tmp[dst] diff --git a/pkg/models/rssi_map_test.go b/pkg/models/rssi_map_test.go index 97c37e1..e9a89b6 100644 --- a/pkg/models/rssi_map_test.go +++ b/pkg/models/rssi_map_test.go @@ -8,13 +8,13 @@ import ( func TestSetter(t *testing.T) { x := NewRssiMap() - x.Set("a", "b", 1) - assert.DeepEqual(t, x.data, map[string]map[string]int{"a": map[string]int{"b": 1}}) + x.Set("A", "B", 1) + assert.DeepEqual(t, x.data, map[string]map[string]int{"A": map[string]int{"B": 1}}) } func TestMerge(t *testing.T) { - a := "a" - b := "b" + a := "A" + b := "B" x := NewRssiMap() y := NewRssiMap() z := NewRssiMap() diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 7f15f61..f03a573 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -60,7 +60,7 @@ func TestStatusSetters(t *testing.T) { assert.Equal(t, len(errs), 1) assert.DeepEqual(t, errs[0].Error(), expected.Error()) addr := "someaddr" - expectedRM := map[string]map[string]int{"a": map[string]int{"b": -90}} + expectedRM := map[string]map[string]int{"A": map[string]int{"B": -90}} s := BLEClientState{Status: Connected, RssiMap: expectedRM, ConnectedAddr: addr} server.setClientState(addr, s) expectedState := map[string]BLEClientState{} diff --git a/pkg/util/timeout_test.go b/pkg/util/timeout_test.go index 61e700e..917b093 100644 --- a/pkg/util/timeout_test.go +++ b/pkg/util/timeout_test.go @@ -10,7 +10,7 @@ import ( func TestTimeout(t *testing.T) { err := Timeout(func() error { - time.Sleep(timeout + 2) + time.Sleep(timeout + time.Second) return errors.New("should not get called") }, timeout) assert.ErrorContains(t, err, "Timeout")