Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:binary handlers #33

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 49 additions & 8 deletions hivemind_bus_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,33 @@
from threading import Event
from typing import Union, Optional, Callable

import pgpy
from ovos_bus_client import Message as MycroftMessage, MessageBusClient as OVOSBusClient
from ovos_bus_client.session import Session
from ovos_utils.log import LOG
from ovos_utils.messagebus import FakeBus
from pyee import EventEmitter
from websocket import ABNF
from websocket import WebSocketApp, WebSocketConnectionClosedException
import pgpy
from hivemind_bus_client.serialization import HiveMindBinaryPayloadType

from hivemind_bus_client.identity import NodeIdentity
from hivemind_bus_client.message import HiveMessage, HiveMessageType
from hivemind_bus_client.serialization import HiveMindBinaryPayloadType
from hivemind_bus_client.serialization import get_bitstring, decode_bitstring
from hivemind_bus_client.util import serialize_message, \
encrypt_as_json, decrypt_from_json, encrypt_bin, decrypt_bin
from ovos_utils.log import LOG
from ovos_utils.messagebus import FakeBus


class BinaryDataCallbacks:
def handle_receive_tts(self, bin_data: bytes,
utterance: str,
lang: str,
file_name: str):
LOG.warning(f"Ignoring received binary TTS audio: {utterance} with {len(bin_data)} bytes")

def handle_receive_file(self, bin_data: bytes,
file_name: str):
LOG.warning(f"Ignoring received binary file: {file_name} with {len(bin_data)} bytes")
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved


class HiveMessageWaiter:
Expand Down Expand Up @@ -87,7 +100,9 @@ def __init__(self, key: Optional[str] = None,
compress: bool = True,
binarize: bool = True,
identity: NodeIdentity = None,
internal_bus: Optional[OVOSBusClient] = None):
internal_bus: Optional[OVOSBusClient] = None,
bin_callbacks: BinaryDataCallbacks = BinaryDataCallbacks()):
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
self.bin_callbacks = bin_callbacks

self.identity = identity or None
self._password = password
Expand Down Expand Up @@ -268,20 +283,45 @@ def on_message(self, *args):
message = decode_bitstring(message)
elif isinstance(message, str):
message = json.loads(message)
if "ciphertext" in message:
if isinstance(message, dict) and "ciphertext" in message:
LOG.error("got encrypted message, but could not decrypt!")
return

if (isinstance(message, HiveMessage) and message.msg_type == HiveMessageType.BINARY):
self._handle_binary(message)
return
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
self.emitter.emit('message', message) # raw message
self._handle_hive_protocol(HiveMessage(**message))

def _handle_binary(self, message: HiveMessage):
assert message.msg_type == HiveMessageType.BINARY
bin_data = message.payload
LOG.debug(f"Got binary data of type: {message.bin_type}")
if message.bin_type == HiveMindBinaryPayloadType.TTS_AUDIO:
lang = message.metadata.get("lang")
utt = message.metadata.get("utterance")
file_name = message.metadata.get("file_name")
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
try:
self.bin_callbacks.handle_receive_tts(bin_data, utt, lang, file_name)
except:
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
LOG.exception("Error in binary callback: handle_receive_tts")
elif message.bin_type == HiveMindBinaryPayloadType.FILE:
file_name = message.metadata.get("file_name")
try:
self.bin_callbacks.handle_receive_file(bin_data, file_name)
except:
LOG.exception("Error in binary callback: handle_receive_file")
else:
LOG.warning(f"Ignoring received untyped binary data: {len(bin_data)} bytes")

def _handle_hive_protocol(self, message: HiveMessage):
# LOG.debug(f"received HiveMind message: {message.msg_type}")
if message.msg_type == HiveMessageType.BUS:
self.internal_bus.emit(message.payload)
self.emitter.emit(message.msg_type, message) # hive message

def emit(self, message: Union[MycroftMessage, HiveMessage],
binary_type: HiveMindBinaryPayloadType=HiveMindBinaryPayloadType.UNDEFINED):
binary_type: HiveMindBinaryPayloadType = HiveMindBinaryPayloadType.UNDEFINED):
if isinstance(message, MycroftMessage):
message = HiveMessage(msg_type=HiveMessageType.BUS,
payload=message)
Expand Down Expand Up @@ -324,7 +364,8 @@ def emit(self, message: Union[MycroftMessage, HiveMessage],
bitstr = get_bitstring(hive_type=message.msg_type,
payload=message.payload,
compressed=self.compress,
binary_type=binary_type)
binary_type=binary_type,
hivemeta=message.metadata)
if self.crypto_key:
ws_payload = encrypt_bin(self.crypto_key, bitstr.bytes)
else:
Expand Down
19 changes: 16 additions & 3 deletions hivemind_bus_client/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from ovos_bus_client import Message
from ovos_utils.json_helper import merge_dict
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict, Any


class HiveMessageType(str, Enum):
Expand Down Expand Up @@ -36,6 +36,9 @@ class HiveMindBinaryPayloadType(IntEnum):
RAW_AUDIO = 1 # binary content is raw audio (TODO spec exactly what "raw audio" means)
NUMPY_IMAGE = 2 # binary content is an image as a numpy array, eg. webcam picture
FILE = 3 # binary is a file to be saved, additional metadata provided elsewhere
STT_AUDIO_TRANSCRIBE = 4 # full audio sentence to perform STT and return transcripts
STT_AUDIO_HANDLE = 5 # full audio sentence to perform STT and handle transcription immediately
TTS_AUDIO = 6 # synthesized TTS audio to be played


class HiveMessage:
Expand All @@ -47,7 +50,8 @@ def __init__(self, msg_type: Union[HiveMessageType, str],
target_peers: Optional[List[str]]=None,
target_site_id: Optional[str] =None,
target_pubkey: Optional[str] =None,
bin_type: HiveMindBinaryPayloadType = HiveMindBinaryPayloadType.UNDEFINED):
bin_type: HiveMindBinaryPayloadType = HiveMindBinaryPayloadType.UNDEFINED,
metadata: Optional[Dict[str, Any]] = None):
# except for the hivemind node classes receiving the message and
# creating the object nothing should be able to change these values
# node classes might change them a runtime by the private attribute
Expand All @@ -59,6 +63,7 @@ def __init__(self, msg_type: Union[HiveMessageType, str],

self._msg_type = msg_type
self._bin_type = bin_type
self._meta = metadata or {}

# the payload is more or less a free for all
# the msg_type determines what happens to the message, but the
Expand All @@ -82,6 +87,10 @@ def __init__(self, msg_type: Union[HiveMessageType, str],
self._route = route or [] # where did this message come from
self._targets = target_peers or [] # where will it be sent

@property
def metadata(self) -> Dict[str, Any]:
return self._meta

JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
@property
def target_site_id(self) -> str:
return self._site_id
Expand Down Expand Up @@ -145,6 +154,7 @@ def as_dict(self) -> dict:

return {"msg_type": self.msg_type,
"payload": pload,
"metadata": self.metadata,
"route": self.route,
"node": self.node_id,
"target_site_id": self.target_site_id,
Expand All @@ -166,6 +176,7 @@ def deserialize(payload: Union[str, dict]) -> 'HiveMessage':
if "msg_type" in payload:
try:
return HiveMessage(payload["msg_type"], payload["payload"],
metadata=payload.get("metadata", {}),
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
target_site_id=payload.get("target_site_id"),
target_pubkey=payload.get("target_pubkey"))
except:
Expand All @@ -175,13 +186,15 @@ def deserialize(payload: Union[str, dict]) -> 'HiveMessage':
try:
# NOTE: technically could also be SHARED_BUS or THIRDPRTY
return HiveMessage(HiveMessageType.BUS,
Message.deserialize(payload),
payload=Message.deserialize(payload),
metadata=payload.get("metadata", {}),
target_site_id=payload.get("target_site_id"),
target_pubkey=payload.get("target_pubkey"))
except:
pass # not a mycroft message

return HiveMessage(HiveMessageType.THIRDPRTY, payload,
metadata=payload.get("metadata", {}),
target_site_id=payload.get("target_site_id"),
target_pubkey=payload.get("target_pubkey"))

Expand Down
5 changes: 2 additions & 3 deletions hivemind_bus_client/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ def _decode_bitstring_v1(s):
payload = bytes2str(payload.bytes, compressed)
else:
payload = payload.bytes
meta["bin_type"] = bin_type

kwargs = {a: meta[a] for a in signature(HiveMessage).parameters if a in meta}
return HiveMessage(hive_type, payload, **kwargs)
return HiveMessage(hive_type, payload,
metadata=meta, bin_type=bin_type)


def mycroft2bitstring(msg, compressed=False):
Expand Down
Loading