Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to pyo3 v0.21 Bound API #23

Merged
merged 9 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ name = "libipld"
crate-type = ["rlib", "cdylib"]

[dependencies]
pyo3 = { version = "0.20", features = ["generate-import-lib", "anyhow"] }
#pyo3 = { version = "0.20", features = ["generate-import-lib", "anyhow"] }
pyo3 = { git = "https://github.com/PyO3/pyo3", branch = "main", features = ["generate-import-lib", "anyhow"] }
python3-dll-a = "0.2.7"
anyhow = "1.0.75"
futures = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions profiling/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ libipld = { path = ".." }
structopt = "0.3.26"
clap = "4.5.1"

[dependencies.pyo3]
version = "0.20"
#[dependencies.pyo3]
#version = "0.20"
98 changes: 51 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,43 @@ use pyo3::{PyObject, Python};
use pyo3::conversion::ToPyObject;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::pybacked::PyBackedStr;

fn car_header_to_pydict<'py>(py: Python<'py>, header: &CarHeader) -> &'py PyDict {
let dict_obj = PyDict::new(py);
fn car_header_to_pydict<'py>(py: Python<'py>, header: &CarHeader) -> Bound<'py, PyDict> {
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("version", header.version()).unwrap();

let roots = PyList::empty(py);
let roots = PyList::empty_bound(py);
header.roots().iter().for_each(|cid| {
let cid_obj = cid.to_string().to_object(py);
roots.append(cid_obj).unwrap();
});

dict_obj.set_item("roots", roots).unwrap();

dict_obj.into()
dict_obj
}

fn cid_hash_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> &'py PyDict {
fn cid_hash_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> {
let hash = cid.hash();
let dict_obj = PyDict::new(py);
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("code", hash.code()).unwrap();
dict_obj.set_item("size", hash.size()).unwrap();
dict_obj.set_item("digest", PyBytes::new(py, &hash.digest())).unwrap();
dict_obj.set_item("digest", PyBytes::new_bound(py, &hash.digest())).unwrap();

dict_obj.into()
dict_obj
}

fn cid_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> &'py PyDict {
let dict_obj = PyDict::new(py);
fn cid_to_pydict<'py>(py: Python<'py>, cid: &Cid) -> Bound<'py, PyDict> {
let dict_obj = PyDict::new_bound(py);

dict_obj.set_item("version", cid.version() as u64).unwrap();
dict_obj.set_item("codec", cid.codec()).unwrap();
dict_obj.set_item("hash", cid_hash_to_pydict(py, cid)).unwrap();

dict_obj.into()
dict_obj
}

fn decode_len(len: u64) -> Result<usize> {
Expand All @@ -65,15 +66,17 @@ fn map_key_cmp(a: &str, b: &str) -> std::cmp::Ordering {
}
}

fn sort_map_keys(keys: &PySequence, len: usize) -> Vec<(&str, usize)> {
fn sort_map_keys(keys: Bound<PySequence>, len: usize) -> Vec<(PyBackedStr, usize)> {
// Returns key and index.
let mut keys_str = Vec::with_capacity(len);
for i in 0..len {
let key: &PyString = keys.get_item(i).unwrap().downcast().unwrap();
keys_str.push((key.to_str().unwrap(), i));
let item = keys.get_item(i).unwrap();
let key = item.downcast::<PyString>().unwrap().to_owned();
let backed_str = PyBackedStr::try_from(key).unwrap();
keys_str.push((backed_str, i));
}

keys_str.sort_by(|a, b| {
keys_str.sort_by(|a, b| { // sort_unstable_by performs bad
let (s1, _) = a;
let (s2, _) = b;

Expand All @@ -90,24 +93,25 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
MajorKind::NegativeInt => (-1 - decode::read_uint(r, major)? as i64).to_object(py),
MajorKind::ByteString => {
let len = decode::read_uint(r, major)?;
PyBytes::new(py, &decode::read_bytes(r, len)?).into()
PyBytes::new_bound(py, &decode::read_bytes(r, len)?).into()
}
MajorKind::TextString => {
let len = decode::read_uint(r, major)?;
decode::read_str(r, len)?.to_object(py)
}
MajorKind::Array => {
let len = decode_len(decode::read_uint(r, major)?)?;
// TODO (MarshalX): how to init list with capacity?
let list = PyList::empty(py);
let list = PyList::empty_bound(py);

for _ in 0..len {
list.append(decode_dag_cbor_to_pyobject(py, r, deep + 1)?).unwrap();
}

list.into()
}
MajorKind::Map => {
let len = decode_len(decode::read_uint(r, major)?)?;
let dict = PyDict::new(py);
let dict = PyDict::new_bound(py);

let mut prev_key: Option<String> = None;
for _ in 0..len {
Expand Down Expand Up @@ -135,6 +139,7 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
let value = decode_dag_cbor_to_pyobject(py, r, deep + 1)?;
dict.set_item(key_py, value).unwrap();
}

dict.into()
}
MajorKind::Tag => {
Expand All @@ -157,7 +162,7 @@ fn decode_dag_cbor_to_pyobject<R: Read + Seek>(py: Python, r: &mut R, deep: usiz
Ok(py_object)
}

fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny, w: &mut W) -> Result<()> {
fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: Bound<'py, PyAny>, w: &mut W) -> Result<()> {
/* Order is important for performance!

Fast checks go first:
Expand All @@ -177,7 +182,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyBool>() {
let buf = if obj.is_true()? { [cbor::TRUE.into()] } else { [cbor::FALSE.into()] };
let buf = if obj.is_truthy()? { [cbor::TRUE.into()] } else { [cbor::FALSE.into()] };
w.write_all(&buf)?;

Ok(())
Expand All @@ -192,7 +197,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyList>() {
let seq: &PySequence = obj.downcast().unwrap();
let seq = obj.downcast::<PySequence>().unwrap();
let len = obj.len()?;

encode::write_u64(w, MajorKind::Array, len as u64)?;
Expand All @@ -203,15 +208,14 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyDict>() {
let map: &PyMapping = obj.downcast().unwrap();
let keys = map.keys()?;
let values = map.values()?;
let map = obj.downcast::<PyMapping>().unwrap();
let len = map.len()?;
let keys = sort_map_keys(map.keys()?, len);
let values = map.values()?;

encode::write_u64(w, MajorKind::Map, len as u64)?;

let sorted_keys = sort_map_keys(&keys, len);
for (key, i) in sorted_keys {
for (key, i) in keys {
let key_buf = key.as_bytes();
encode::write_u64(w, MajorKind::TextString, key_buf.len() as u64)?;
w.write_all(key_buf)?;
Expand All @@ -221,7 +225,7 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyFloat>() {
let f: &PyFloat = obj.downcast().unwrap();
let f = obj.downcast::<PyFloat>().unwrap();
let v = f.value();

if !v.is_finite() {
Expand All @@ -234,15 +238,15 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny

Ok(())
} else if obj.is_instance_of::<PyBytes>() {
let b: &PyBytes = obj.downcast().unwrap();
let b = obj.downcast::<PyBytes>().unwrap();
let l: u64 = b.len()? as u64;

encode::write_u64(w, MajorKind::ByteString, l)?;
w.write_all(b.as_bytes())?;

Ok(())
} else if obj.is_instance_of::<PyString>() {
let s: &PyString = obj.downcast().unwrap();
let s = obj.downcast::<PyString>().unwrap();

// FIXME (MarshalX): it's not efficient to try to parse it as CID
let cid = Cid::try_from(s.to_str()?);
Expand Down Expand Up @@ -271,9 +275,9 @@ fn encode_dag_cbor_from_pyobject<'py, W: Write>(py: Python<'py>, obj: &'py PyAny
}

#[pyfunction]
fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<&'py PyList> {
fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<PyObject> {
let mut reader = BufReader::new(Cursor::new(data));
let decoded_parts = PyList::empty(py);
let decoded_parts = PyList::empty_bound(py);

loop {
let py_object = decode_dag_cbor_to_pyobject(py, &mut reader, 0);
Expand All @@ -284,11 +288,11 @@ fn decode_dag_cbor_multi<'py>(py: Python<'py>, data: &[u8]) -> PyResult<&'py PyL
}
}

Ok(decoded_parts)
Ok(decoded_parts.into())
}

#[pyfunction]
pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &'py PyDict)> {
pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(PyObject, PyObject)> {
let car_response = executor::block_on(CarReader::new(data));
if let Err(e) = car_response {
return Err(get_err("Failed to decode CAR", e.to_string()));
Expand All @@ -297,7 +301,7 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &
let car = car_response.unwrap();

let header = car_header_to_pydict(py, car.header());
let parsed_blocks = PyDict::new(py);
let parsed_blocks = PyDict::new_bound(py);

let blocks: Vec<Result<(Cid, Vec<u8>), CarError>> = executor::block_on(car.stream().collect());
blocks.into_iter().for_each(|block| {
Expand All @@ -310,7 +314,7 @@ pub fn decode_car<'py>(py: Python<'py>, data: &[u8]) -> PyResult<(&'py PyDict, &
}
});

Ok((header, parsed_blocks))
Ok((header.into(), parsed_blocks.into()))
}

#[pyfunction]
Expand All @@ -324,22 +328,22 @@ fn decode_dag_cbor(py: Python, data: &[u8]) -> PyResult<PyObject> {
}

#[pyfunction]
fn encode_dag_cbor<'py>(py: Python<'py>, data: &PyAny) -> PyResult<&'py PyBytes> {
fn encode_dag_cbor<'py>(py: Python<'py>, data: Bound<'py, PyAny>) -> PyResult<PyObject> {
let mut buf = &mut BufWriter::new(Vec::new());
if let Err(e) = encode_dag_cbor_from_pyobject(py, data, &mut buf) {
return Err(get_err("Failed to encode DAG-CBOR", e.to_string()));
}
if let Err(e) = buf.flush() {
return Err(get_err("Failed to flush buffer", e.to_string()));
}
Ok(PyBytes::new(py, &buf.get_ref()))
Ok(PyBytes::new_bound(py, &buf.get_ref()).into())
}

#[pyfunction]
fn decode_cid(py: Python, data: String) -> PyResult<&PyDict> {
fn decode_cid(py: Python, data: String) -> PyResult<PyObject> {
let cid = Cid::try_from(data.as_str());
if let Ok(cid) = cid {
Ok(cid_to_pydict(py, &cid))
Ok(cid_to_pydict(py, &cid).into())
} else {
Err(get_err("Failed to decode CID", cid.unwrap_err().to_string()))
}
Expand All @@ -349,23 +353,23 @@ fn decode_cid(py: Python, data: String) -> PyResult<&PyDict> {
fn decode_multibase(py: Python, data: String) -> PyResult<(char, PyObject)> {
let base = multibase::decode(data);
if let Ok((base, data)) = base {
Ok((base.code(), PyBytes::new(py, &data).into()))
Ok((base.code(), PyBytes::new_bound(py, &data).into()))
} else {
Err(get_err("Failed to decode multibase", base.unwrap_err().to_string()))
}
}

#[pyfunction]
fn encode_multibase(code: char, data: &PyAny) -> PyResult<String> {
fn encode_multibase(code: char, data: Bound<PyAny>) -> PyResult<String> {
let data_bytes: &[u8];
if data.is_instance_of::<PyBytes>() {
let b: &PyBytes = data.downcast().unwrap();
let b = data.downcast::<PyBytes>().unwrap();
data_bytes = b.as_bytes();
} else if data.is_instance_of::<PyByteArray>() {
let b: &PyByteArray = data.downcast().unwrap();
data_bytes = unsafe { b.as_bytes() };
let ba = data.downcast::<PyByteArray>().unwrap();
data_bytes = unsafe { ba.as_bytes() };
} else if data.is_instance_of::<PyString>() {
let s: &PyString = data.downcast().unwrap();
let s = data.downcast::<PyString>().unwrap();
data_bytes = s.to_str()?.as_bytes();
} else {
return Err(get_err("Failed to encode multibase", "Unsupported data type".to_string()));
Expand All @@ -384,7 +388,7 @@ fn get_err(msg: &str, err: String) -> PyErr {
}

#[pymodule]
fn libipld(_py: Python, m: &PyModule) -> PyResult<()> {
fn libipld(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(decode_cid, m)?)?;
m.add_function(wrap_pyfunction!(decode_car, m)?)?;
m.add_function(wrap_pyfunction!(decode_dag_cbor, m)?)?;
Expand Down
Loading