Skip to content

Commit c6c4bdf

Browse files
authored
Refactor rust code to submodules (#776)
1 parent a03edc3 commit c6c4bdf

File tree

11 files changed

+130
-85
lines changed

11 files changed

+130
-85
lines changed

deebot_client/commands/json/map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from deebot_client.events.map import CachedMapInfoEvent
1919
from deebot_client.logging_filter import get_logger
2020
from deebot_client.message import HandlingResult, HandlingState, MessageBodyDataDict
21-
from deebot_client.rs import decompress_7z_base64_data
21+
from deebot_client.rs.util import decompress_7z_base64_data
2222

2323
from .common import JsonCommandWithMessageHandling
2424

deebot_client/map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
from .exceptions import MapError
3535
from .logging_filter import get_logger
3636
from .models import Room
37-
from .rs import TracePoint, decompress_7z_base64_data, extract_trace_points
37+
from .rs.map import TracePoint, extract_trace_points
38+
from .rs.util import decompress_7z_base64_data
3839
from .util import (
3940
OnChangedDict,
4041
OnChangedList,

deebot_client/rs.pyi renamed to deebot_client/rs/map.pyi

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
from typing import Self
22

3-
def decompress_7z_base64_data(value: str) -> bytes:
4-
"""Decompress base64 decoded 7z compressed string."""
5-
63
class TracePoint:
74
"""Trace point."""
85

deebot_client/rs/util.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def decompress_7z_base64_data(value: str) -> bytes:
2+
"""Decompress base64 decoded 7z compressed string."""

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ required-imports = ["from __future__ import annotations"]
148148
"deebot_client/hardware/deebot/*" = [
149149
"N999", # Invalid module name
150150
]
151-
"deebot_client/rs.pyi" = [
151+
"deebot_client/rs/*.pyi" = [
152152
"PYI021", # docstring-in-stub
153153
]
154154

@@ -167,7 +167,11 @@ max-args = 7
167167
py-version = "3.13"
168168
ignore = ["tests"]
169169
fail-on = ["I"]
170-
extension-pkg-allow-list = ["deebot_client.rs"]
170+
extension-pkg-allow-list = [
171+
"deebot_client.rs",
172+
"deebot_client.rs.map",
173+
"deebot_client.rs.util",
174+
]
171175

172176
[tool.pylint.BASIC]
173177
good-names = ["i", "j", "k", "ex", "_", "T", "x", "y", "id", "tg"]

src/lib.rs

Lines changed: 23 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,31 @@
1-
use std::error::Error;
2-
3-
use base64::{engine::general_purpose, Engine as _};
4-
use byteorder::{LittleEndian, ReadBytesExt};
5-
use pyo3::exceptions::PyValueError;
61
use pyo3::prelude::*;
7-
use std::io::Cursor;
8-
9-
fn _decompress_7z_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error>> {
10-
let mut bytes = general_purpose::STANDARD.decode(value)?;
11-
12-
if bytes.len() < 8 {
13-
return Err("Invalid 7z compressed data".into());
14-
}
15-
16-
for _ in 0..=3 {
17-
bytes.insert(8, 0);
18-
}
19-
20-
Ok(lzma::decompress(&bytes)?)
21-
}
22-
23-
/// Decompress base64 decoded 7z compressed string.
24-
#[pyfunction]
25-
fn decompress_7z_base64_data(value: String) -> Result<Vec<u8>, PyErr> {
26-
Ok(_decompress_7z_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))?)
27-
}
28-
29-
/// Trace point
30-
#[pyclass]
31-
struct TracePoint {
32-
#[pyo3(get)]
33-
x: i16,
34-
35-
#[pyo3(get)]
36-
y: i16,
37-
38-
#[pyo3(get)]
39-
connected: bool,
40-
}
41-
42-
#[pymethods]
43-
impl TracePoint {
44-
#[new]
45-
fn new(x: i16, y: i16, connected: bool) -> Self {
46-
TracePoint { x, y, connected }
47-
}
48-
}
492

50-
fn process_trace_points(trace_points: &[u8]) -> Result<Vec<TracePoint>, Box<dyn Error>> {
51-
let mut trace_values = Vec::new();
52-
for i in (0..trace_points.len()).step_by(5) {
53-
if i + 4 >= trace_points.len() {
54-
return Err("Invalid trace points length".into());
55-
}
56-
57-
let mut cursor = Cursor::new(&trace_points[i..i + 4]);
58-
let x = cursor.read_i16::<LittleEndian>()?;
59-
let y = cursor.read_i16::<LittleEndian>()?;
60-
61-
// Determine connection status
62-
let connected = (trace_points[i + 4] >> 7 & 1) == 0;
63-
64-
trace_values.push(TracePoint { x, y, connected });
65-
}
66-
Ok(trace_values)
67-
}
68-
69-
#[pyfunction]
70-
/// Extract trace points from 7z compressed data string.
71-
fn extract_trace_points(value: String) -> Result<Vec<TracePoint>, PyErr> {
72-
let decompressed_data = decompress_7z_base64_data(value)?;
73-
Ok(process_trace_points(&decompressed_data)
74-
.map_err(|err| PyValueError::new_err(err.to_string()))?)
75-
}
3+
mod map;
4+
mod util;
765

776
/// Deebot client written in Rust
787
#[pymodule]
798
fn rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
80-
m.add_function(wrap_pyfunction!(decompress_7z_base64_data, m)?)?;
81-
m.add_function(wrap_pyfunction!(extract_trace_points, m)?)?;
82-
m.add_class::<TracePoint>()?;
9+
register_child_module(m, "map", map::init_module)?;
10+
register_child_module(m, "util", util::init_module)?;
8311
Ok(())
8412
}
13+
14+
fn register_child_module(
15+
parent_module: &Bound<'_, PyModule>,
16+
name: &str,
17+
func: fn(&Bound<'_, PyModule>) -> PyResult<()>,
18+
) -> PyResult<()> {
19+
let child_module = PyModule::new(parent_module.py(), name)?;
20+
func(&child_module)?;
21+
22+
// https://github.com/PyO3/pyo3/issues/1517#issuecomment-808664021
23+
// https://github.com/PyO3/pyo3/issues/759
24+
let _ = Python::with_gil(|py| {
25+
py.import("sys")?
26+
.getattr("modules")?
27+
.set_item(&format!("deebot_client.rs.{}", name), &child_module)
28+
});
29+
30+
parent_module.add_submodule(&child_module)
31+
}

src/map.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use std::error::Error;
2+
3+
use super::util::decompress_7z_base64_data;
4+
use byteorder::{LittleEndian, ReadBytesExt};
5+
use pyo3::exceptions::PyValueError;
6+
use pyo3::prelude::*;
7+
use std::io::Cursor;
8+
9+
/// Trace point
10+
#[pyclass]
11+
struct TracePoint {
12+
#[pyo3(get)]
13+
x: i16,
14+
15+
#[pyo3(get)]
16+
y: i16,
17+
18+
#[pyo3(get)]
19+
connected: bool,
20+
}
21+
22+
#[pymethods]
23+
impl TracePoint {
24+
#[new]
25+
fn new(x: i16, y: i16, connected: bool) -> Self {
26+
TracePoint { x, y, connected }
27+
}
28+
}
29+
30+
fn process_trace_points(trace_points: &[u8]) -> Result<Vec<TracePoint>, Box<dyn Error>> {
31+
let mut trace_values = Vec::new();
32+
for i in (0..trace_points.len()).step_by(5) {
33+
if i + 4 >= trace_points.len() {
34+
return Err("Invalid trace points length".into());
35+
}
36+
37+
let mut cursor = Cursor::new(&trace_points[i..i + 4]);
38+
let x = cursor.read_i16::<LittleEndian>()?;
39+
let y = cursor.read_i16::<LittleEndian>()?;
40+
41+
// Determine connection status
42+
let connected = (trace_points[i + 4] >> 7 & 1) == 0;
43+
44+
trace_values.push(TracePoint { x, y, connected });
45+
}
46+
Ok(trace_values)
47+
}
48+
49+
fn extract_trace_points(value: String) -> Result<Vec<TracePoint>, Box<dyn Error>> {
50+
let decompressed_data = decompress_7z_base64_data(value)?;
51+
Ok(process_trace_points(&decompressed_data)?)
52+
}
53+
54+
#[pyfunction(name = "extract_trace_points")]
55+
/// Extract trace points from 7z compressed data string.
56+
fn python_extract_trace_points(value: String) -> Result<Vec<TracePoint>, PyErr> {
57+
Ok(extract_trace_points(value).map_err(|err| PyValueError::new_err(err.to_string()))?)
58+
}
59+
60+
pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
61+
m.add_function(wrap_pyfunction!(python_extract_trace_points, m)?)?;
62+
m.add_class::<TracePoint>()?;
63+
Ok(())
64+
}

src/util.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use std::error::Error;
2+
3+
use base64::{engine::general_purpose, Engine as _};
4+
use pyo3::exceptions::PyValueError;
5+
use pyo3::prelude::*;
6+
7+
pub fn decompress_7z_base64_data(value: String) -> Result<Vec<u8>, Box<dyn Error>> {
8+
let mut bytes = general_purpose::STANDARD.decode(value)?;
9+
10+
if bytes.len() < 8 {
11+
return Err("Invalid 7z compressed data".into());
12+
}
13+
14+
for _ in 0..=3 {
15+
bytes.insert(8, 0);
16+
}
17+
18+
Ok(lzma::decompress(&bytes)?)
19+
}
20+
21+
/// Decompress base64 decoded 7z compressed string.
22+
#[pyfunction(name = "decompress_7z_base64_data")]
23+
fn python_decompress_7z_base64_data(value: String) -> Result<Vec<u8>, PyErr> {
24+
Ok(decompress_7z_base64_data(value).map_err(|err| PyValueError::new_err(err.to_string()))?)
25+
}
26+
27+
pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
28+
m.add_function(wrap_pyfunction!(python_decompress_7z_base64_data, m)?)?;
29+
Ok(())
30+
}

tests/rs/__init__.py

Whitespace-only changes.

tests/test_rs.py renamed to tests/rs/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010

11-
from deebot_client.rs import decompress_7z_base64_data
11+
from deebot_client.rs.util import decompress_7z_base64_data
1212

1313
if TYPE_CHECKING:
1414
from contextlib import AbstractContextManager

tests/test_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
_points_to_svg_path,
4949
)
5050
from deebot_client.models import Room
51-
from deebot_client.rs import TracePoint
51+
from deebot_client.rs.map import TracePoint
5252

5353
from .common import block_till_done
5454

0 commit comments

Comments
 (0)