Skip to content

Commit

Permalink
Working through typehinting rest of spot wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
khughes-bdai committed Jan 26, 2024
1 parent 53cd7da commit 963a3fe
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 56 deletions.
53 changes: 30 additions & 23 deletions spot_wrapper/cam_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import datetime
import enum
import logging
import math
import os.path
import pathlib
Expand All @@ -17,7 +18,9 @@
from bosdyn.api.data_chunk_pb2 import DataChunk
from bosdyn.api.spot_cam import audio_pb2
from bosdyn.api.spot_cam.camera_pb2 import Camera
from bosdyn.api.spot_cam.compositor_pb2 import IrColorMap
from bosdyn.api.spot_cam.logging_pb2 import Logpoint
from bosdyn.api.spot_cam.power_pb2 import PowerStatus
from bosdyn.api.spot_cam.ptz_pb2 import PtzDescription, PtzPosition, PtzVelocity
from bosdyn.client import Robot, spot_cam
from bosdyn.client.payload import PayloadClient
Expand Down Expand Up @@ -50,7 +53,7 @@ class LEDPosition(enum.Enum):
FRONT_RIGHT = 2
REAR_RIGHT = 3

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.logger = logger
self.client: LightingClient = robot.ensure_client(LightingClient.default_service_name)

Expand Down Expand Up @@ -82,11 +85,11 @@ class PowerWrapper:
Wrapper for power interaction
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.logger = logger
self.client: PowerClient = robot.ensure_client(PowerClient.default_service_name)

def get_power_status(self):
def get_power_status(self) -> PowerStatus:
"""
Get power status for the devices
"""
Expand Down Expand Up @@ -134,7 +137,7 @@ class CompositorWrapper:
Wrapper for compositor interaction
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.logger = logger
self.client: CompositorClient = robot.ensure_client(CompositorClient.default_service_name)

Expand Down Expand Up @@ -175,7 +178,7 @@ def get_screen(self) -> str:
"""
return self.client.get_screen()

def set_ir_colormap(self, colormap, min_temp: float, max_temp: float, auto_scale: bool = True) -> None:
def set_ir_colormap(self, colormap: IrColorMap, min_temp: float, max_temp: float, auto_scale: bool = True) -> None:
"""
Set the colormap used for the IR camera
Expand Down Expand Up @@ -205,7 +208,7 @@ class HealthWrapper:
Wrapper for health details
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.client: HealthClient = robot.ensure_client(HealthClient.default_service_name)
self.logger = logger

Expand Down Expand Up @@ -250,7 +253,7 @@ class AudioWrapper:
Wrapper for audio commands on the camera
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.client: AudioClient = robot.ensure_client(AudioClient.default_service_name)
self.logger = logger

Expand Down Expand Up @@ -334,11 +337,11 @@ class StreamQualityWrapper:
Wrapper for stream quality commands
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.client: StreamQualityClient = robot.ensure_client(StreamQualityClient.default_service_name)
self.logger = logger

def set_stream_params(self, target_bitrate: int, refresh_interval: int, idr_interval: int, awb) -> None:
def set_stream_params(self, target_bitrate: int, refresh_interval: int, idr_interval: int, awb: typing.Any) -> None:
"""
Set image compression and postprocessing parameters
Expand Down Expand Up @@ -402,7 +405,7 @@ class MediaLogWrapper:
Some functionality adapted from https://github.com/boston-dynamics/spot-sdk/blob/master/python/examples/spot_cam/media_log.py
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.client: MediaLogClient = robot.ensure_client(MediaLogClient.default_service_name)
self.logger = logger

Expand Down Expand Up @@ -656,7 +659,7 @@ class PTZWrapper:
Wrapper for controlling the PTZ unit
"""

def __init__(self, robot: Robot, logger) -> None:
def __init__(self, robot: Robot, logger: logging.Logger) -> None:
self.client: PtzClient = robot.ensure_client(PtzClient.default_service_name)
self.logger = logger
self.ptzs = {}
Expand All @@ -681,7 +684,7 @@ def list_ptz(self) -> typing.Dict[str, typing.Dict]:

return ptzs

def _get_ptz_description(self, name):
def _get_ptz_description(self, name: str) -> PtzDescription:
"""
Get the bosdyn version of the ptz description
Expand All @@ -697,7 +700,7 @@ def _get_ptz_description(self, name):

return self.ptzs[name]

def _clamp_value_to_limits(self, value, limits: PtzDescription.Limits):
def _clamp_value_to_limits(self, value: float, limits: PtzDescription.Limits) -> float:
"""
Clamp the given value to the specified limits. If the limits are unspecified (i.e. both 0), the value is not
clamped
Expand All @@ -717,7 +720,9 @@ def _clamp_value_to_limits(self, value, limits: PtzDescription.Limits):

return max(min(value, limits.max.value), limits.min.value)

def _clamp_request_to_limits(self, ptz_name, pan, tilt, zoom) -> typing.Tuple[float, float, float]:
def _clamp_request_to_limits(
self, ptz_name: str, pan: float, tilt: float, zoom: float
) -> typing.Tuple[float, float, float]:
"""
Args:
Expand All @@ -734,7 +739,7 @@ def _clamp_request_to_limits(self, ptz_name, pan, tilt, zoom) -> typing.Tuple[fl
self._clamp_value_to_limits(zoom, ptz_desc.zoom_limit),
)

def get_ptz_position(self, ptz_name) -> PtzPosition:
def get_ptz_position(self, ptz_name: str) -> PtzPosition:
"""
Get the position of the ptz with the given name
Expand All @@ -746,7 +751,7 @@ def get_ptz_position(self, ptz_name) -> PtzPosition:
"""
return self.client.get_ptz_position(PtzDescription(name=ptz_name))

def set_ptz_position(self, ptz_name, pan, tilt, zoom, blocking=False):
def set_ptz_position(self, ptz_name: str, pan: float, tilt: float, zoom: float, blocking: bool = False) -> None:
"""
Set the position of the specified ptz
Expand All @@ -771,7 +776,7 @@ def set_ptz_position(self, ptz_name, pan, tilt, zoom, blocking=False):
current_position = self.client.get_ptz_position(self._get_ptz_description(ptz_name))
time.sleep(0.2)

def get_ptz_velocity(self, ptz_name) -> PtzVelocity:
def get_ptz_velocity(self, ptz_name: str) -> PtzVelocity:
"""
Get the velocity of the ptz with the given name
Expand All @@ -783,7 +788,7 @@ def get_ptz_velocity(self, ptz_name) -> PtzVelocity:
"""
return self.client.get_ptz_velocity(PtzDescription(name=ptz_name))

def set_ptz_velocity(self, ptz_name, pan, tilt, zoom) -> None:
def set_ptz_velocity(self, ptz_name: str, pan: float, tilt: float, zoom: float) -> None:
"""
Set the velocity of the various axes of the specified ptz
Expand Down Expand Up @@ -820,10 +825,10 @@ def __init__(
self,
hostname: str,
robot: Robot,
logger,
sdp_port=31102,
sdp_filename="h264.sdp",
cam_ssl_cert_path=None,
logger: logging.Logger,
sdp_port: int = 31102,
sdp_filename: str = "h264.sdp",
cam_ssl_cert_path: typing.Optional[str] = None,
) -> None:
"""
Initialise the wrapper
Expand Down Expand Up @@ -895,7 +900,9 @@ async def _process_func(self) -> None:


class SpotCamWrapper:
def __init__(self, hostname, username, password, logger, port: typing.Optional[int] = None) -> None:
def __init__(
self, hostname: str, username: str, password: str, logger: logging.Logger, port: typing.Optional[int] = None
) -> None:
self._hostname = hostname
self._username = username
self._password = password
Expand Down
7 changes: 4 additions & 3 deletions spot_wrapper/spot_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def manipulation_command(self, request: manipulation_api_pb2) -> typing.Tuple[bo
timesync_endpoint=self._robot.time_sync.endpoint,
)

def get_manipulation_command_feedback(self, cmd_id):
def get_manipulation_command_feedback(self, cmd_id: int) -> manipulation_api_pb2.ManipulationApiFeedbackResponse:
feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest(manipulation_cmd_id=cmd_id)

return self._manipulation_api_client.manipulation_api_feedback_command(
Expand Down Expand Up @@ -136,7 +136,7 @@ def ensure_arm_power_and_stand(self) -> typing.Tuple[bool, str]:

return True, "Spot has an arm, is powered on, and standing"

def wait_for_arm_command_to_complete(self, cmd_id, timeout_sec: typing.Optional[float] = None) -> None:
def wait_for_arm_command_to_complete(self, cmd_id: int, timeout_sec: typing.Optional[float] = None) -> None:
"""
Wait until a command issued to the arm complets. Wrapper around the SDK function for convenience
Expand Down Expand Up @@ -302,7 +302,8 @@ def create_wrench_from_forces_and_torques(
torque = geometry_pb2.Vec3(x=torques[0], y=torques[1], z=torques[2])
return geometry_pb2.Wrench(force=force, torque=torque)

def force_trajectory(self, data) -> typing.Tuple[bool, str]:
def force_trajectory(self, data: typing.Any) -> typing.Tuple[bool, str]:
# TODO here data is ArmForceTrajectory from spot_msgs ROS package. How to enforce this type?
try:
success, msg = self.ensure_arm_power_and_stand()
if not success:
Expand Down
2 changes: 1 addition & 1 deletion spot_wrapper/spot_eap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
client: PointCloudClient,
logger: logging.Logger,
rate: float,
callback: typing.Callable,
callback: typing.Optional[typing.Callable],
point_cloud_requests: typing.List[PointCloudRequest],
) -> None:
"""
Expand Down
32 changes: 16 additions & 16 deletions spot_wrapper/spot_graph_nav.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _get_lease(self) -> Lease:
def _init_current_graph_nav_state(self) -> None:
# Store the most recent knowledge of the state of the robot based on rpc calls.
self._current_graph = None
self._current_edges = {} # maps to_waypoint to list(from_waypoint)
self._current_edges: typing.Dict[str, typing.List[str]] = {} # maps to_waypoint to list(from_waypoint)
self._current_waypoint_snapshots = {} # maps id to waypoint snapshot
self._current_edge_snapshots = {} # maps id to edge snapshot
self._current_annotation_name_to_wp_id = {}
Expand Down Expand Up @@ -257,14 +257,14 @@ def download_graph(self, download_path: str) -> typing.Tuple[bool, str]:
# Downloading, reproducing, distributing or otherwise using the SDK Software
# is subject to the terms and conditions of the Boston Dynamics Software
# Development Kit License (20191101-BDSDK-SL).
def _get_localization_state(self, *args) -> None:
def _get_localization_state(self, *args: typing.Any) -> None:
"""Get the current localization and state of the robot."""
state = self._graph_nav_client.get_localization_state()
self._logger.info(f"Got localization: \n{str(state.localization)}")
odom_tform_body = get_odom_tform_body(state.robot_kinematics.transforms_snapshot)
self._logger.info(f"Got robot state in kinematic odometry frame: \n{str(odom_tform_body)}")

def set_initial_localization_fiducial(self, *args) -> None:
def set_initial_localization_fiducial(self, *args: typing.Any) -> None:
"""Trigger localization when near a fiducial."""
robot_state = self._robot_state_client.get_robot_state()
current_odom_tform_body = get_odom_tform_body(robot_state.kinematic_state.transforms_snapshot).to_proto()
Expand All @@ -276,7 +276,7 @@ def set_initial_localization_fiducial(self, *args) -> None:
ko_tform_body=current_odom_tform_body,
)

def set_initial_localization_waypoint(self, *args) -> None:
def set_initial_localization_waypoint(self, *args: typing.Any) -> None:
"""Trigger localization to a waypoint."""
# Take the first argument as the localization waypoint.
if len(args) < 1:
Expand Down Expand Up @@ -308,15 +308,15 @@ def set_initial_localization_waypoint(self, *args) -> None:
ko_tform_body=current_odom_tform_body,
)

def _download_current_graph(self):
def _download_current_graph(self) -> map_pb2.Graph:
graph = self._graph_nav_client.download_graph()
if graph is None:
self._logger.error("Empty graph.")
return
self._current_graph = graph
return graph

def _download_full_graph(self, *args) -> None:
def _download_full_graph(self, *args: typing.Any) -> None:
"""Download the graph and snapshots from the robot."""
graph = self._graph_nav_client.download_graph()
if graph is None:
Expand All @@ -330,12 +330,12 @@ def _download_full_graph(self, *args) -> None:
self._download_and_write_waypoint_snapshots(graph.waypoints)
self._download_and_write_edge_snapshots(graph.edges)

def _write_full_graph(self, graph) -> None:
def _write_full_graph(self, graph: map_pb2.Graph) -> None:
"""Download the graph from robot to the specified, local filepath location."""
graph_bytes = graph.SerializeToString()
self._write_bytes(self._download_filepath, "/graph", graph_bytes)

def _download_and_write_waypoint_snapshots(self, waypoints) -> None:
def _download_and_write_waypoint_snapshots(self, waypoints: typing.List[map_pb2.Waypoint]) -> None:
"""Download the waypoint snapshots from robot to the specified, local filepath location."""
num_waypoint_snapshots_downloaded = 0
for waypoint in waypoints:
Expand All @@ -359,7 +359,7 @@ def _download_and_write_waypoint_snapshots(self, waypoints) -> None:
)
)

def _download_and_write_edge_snapshots(self, edges) -> None:
def _download_and_write_edge_snapshots(self, edges: typing.List[map_pb2.Edge]) -> None:
"""Download the edge snapshots from robot to the specified, local filepath location."""
num_edge_snapshots_downloaded = 0
num_to_download = 0
Expand All @@ -383,14 +383,14 @@ def _download_and_write_edge_snapshots(self, edges) -> None:
"Downloaded {} of the total {} edge snapshots.".format(num_edge_snapshots_downloaded, num_to_download)
)

def _write_bytes(self, filepath: str, filename: str, data) -> None:
def _write_bytes(self, filepath: str, filename: str, data: bytes) -> None:
"""Write data to a file."""
os.makedirs(filepath, exist_ok=True)
with open(filepath + filename, "wb+") as f:
f.write(data)
f.close()

def _list_graph_waypoint_and_edge_ids(self, *args):
def _list_graph_waypoint_and_edge_ids(self, *args: typing.Any):
"""List the waypoint ids and edge ids of the graph currently on the robot."""

# Download current graph
Expand All @@ -405,7 +405,7 @@ def _list_graph_waypoint_and_edge_ids(self, *args):
) = self._update_waypoints_and_edges(graph, localization_id, self._logger)
return self._current_annotation_name_to_wp_id, self._current_edges

def _upload_graph_and_snapshots(self, upload_filepath: str):
def _upload_graph_and_snapshots(self, upload_filepath: str) -> None:
"""Upload the graph and snapshots to the robot."""
self._lease = self._get_lease()
self._logger.info("Loading the graph from disk into local storage...")
Expand Down Expand Up @@ -608,7 +608,7 @@ def _navigate_route(self, waypoint_ids: typing.List[str]) -> typing.Tuple[bool,

return True, "Finished navigating route!"

def _clear_graph(self, *args) -> bool:
def _clear_graph(self, *args: typing.Any) -> bool:
"""Clear the state of the map on the robot, removing all waypoints and edges in the RAM of the robot"""
self._lease = self._get_lease()
result = self._graph_nav_client.clear_graph(lease=self._lease.lease_proto)
Expand Down Expand Up @@ -656,7 +656,7 @@ def _match_edge(
return None

def _auto_close_loops(
self, close_fiducial_loops: bool, close_odometry_loops: bool, *args
self, close_fiducial_loops: bool, close_odometry_loops: bool, *args: typing.Any
) -> typing.Tuple[bool, str]:
"""Automatically find and close all loops in the graph."""
response: map_processing_pb2.ProcessTopologyResponse = self._map_processing_client.process_topology(
Expand All @@ -678,7 +678,7 @@ def _auto_close_loops(
else:
return False, "Unknown error during map processing."

def _optimize_anchoring(self, *args) -> typing.Tuple[bool, str]:
def _optimize_anchoring(self, *args: typing.Any) -> typing.Tuple[bool, str]:
"""Call anchoring optimization on the server, producing a globally optimal reference frame for waypoints to be
expressed in.
"""
Expand Down Expand Up @@ -750,7 +750,7 @@ def _find_unique_waypoint_id(
graph: map_pb2.Graph,
name_to_id: typing.Dict[str, str],
logger: logging.Logger,
):
) -> typing.Optional[str]:
"""Convert either a 2 letter short code or an annotation name into the associated unique id."""
if len(short_code) != 2:
# Not a short code, check if it is an annotation name (instead of the waypoint id).
Expand Down
6 changes: 4 additions & 2 deletions spot_wrapper/spot_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
"back_depth_in_visual_frame",
]
ImageBundle = namedtuple("ImageBundle", ["frontleft", "frontright", "left", "right", "back"])
ImageWithHandBundle = namedtuple("ImageBundle", ["frontleft", "frontright", "left", "right", "back", "hand"])
ImageWithHandBundle = namedtuple(
"ImageBundle", ["frontleft", "frontright", "left", "right", "back", "hand"] # type: ignore[name-match]
)

IMAGE_SOURCES_BY_CAMERA = {
"frontleft": {
Expand Down Expand Up @@ -210,7 +212,7 @@ def __init__(
)

# Build image requests by camera
self._image_requests_by_camera = {}
self._image_requests_by_camera: typing.Dict[str, dict] = {}
for camera in IMAGE_SOURCES_BY_CAMERA:
if camera == "hand" and not self._robot.has_arm():
continue
Expand Down
Loading

0 comments on commit 963a3fe

Please sign in to comment.