Skip to content

Commit 644baf3

Browse files
committed
Add test covering __getitem__
1 parent 7d50f86 commit 644baf3

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

tests/test_class_basics.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,8 @@ fn test_runtime_parametrization() {
740740
py_run!(
741741
py,
742742
ty,
743-
"import types; assert ty.__class_getitem__((int,)) == types.GenericAlias(ty, (int,))"
743+
"import types;
744+
assert ty.__class_getitem__((int,)) == types.GenericAlias(ty, (int,))"
744745
);
745746
});
746747
}

tests/test_sequence.rs

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![cfg(feature = "macros")]
22

33
use pyo3::exceptions::{PyIndexError, PyValueError};
4-
use pyo3::types::{IntoPyDict, PyList, PyMapping, PySequence};
4+
use pyo3::types::{IntoPyDict, PyInt, PyList, PyMapping, PySequence};
55
use pyo3::{ffi, prelude::*};
66

77
use pyo3::py_run;
@@ -255,15 +255,15 @@ fn test_inplace_repeat() {
255255
// Check that #[pyo3(get, set)] works correctly for Vec<PyObject>
256256

257257
#[pyclass]
258-
struct GenericList {
258+
struct AnyObjectList {
259259
#[pyo3(get, set)]
260260
items: Vec<PyObject>,
261261
}
262262

263263
#[test]
264-
fn test_generic_list_get() {
264+
fn test_any_object_list_get() {
265265
Python::with_gil(|py| {
266-
let list = GenericList {
266+
let list = AnyObjectList {
267267
items: [1i32, 2, 3]
268268
.iter()
269269
.map(|i| i.into_pyobject(py).unwrap().into_any().unbind())
@@ -277,9 +277,9 @@ fn test_generic_list_get() {
277277
}
278278

279279
#[test]
280-
fn test_generic_list_set() {
280+
fn test_any_object_list_set() {
281281
Python::with_gil(|py| {
282-
let list = Bound::new(py, GenericList { items: vec![] }).unwrap();
282+
let list = Bound::new(py, AnyObjectList { items: vec![] }).unwrap();
283283

284284
py_run!(py, list, "list.items = [1, 2, 3]");
285285
assert!(list
@@ -367,3 +367,58 @@ fn sequence_length() {
367367
unsafe { ffi::PyErr_Clear() };
368368
})
369369
}
370+
371+
#[cfg(Py_3_9)]
372+
#[pyclass(generic, sequence)]
373+
struct GenericList {
374+
#[pyo3(get, set)]
375+
items: Vec<PyObject>,
376+
}
377+
378+
#[cfg(Py_3_9)]
379+
#[pymethods]
380+
impl GenericList {
381+
fn __len__(&self) -> usize {
382+
self.items.len()
383+
}
384+
385+
fn __getitem__(&self, idx: isize) -> PyResult<PyObject> {
386+
match self.items.get(idx as usize) {
387+
Some(x) => pyo3::Python::with_gil(|py| Ok(x.clone_ref(py))),
388+
None => Err(PyIndexError::new_err("Index out of bounds")),
389+
}
390+
}
391+
}
392+
393+
#[cfg(Py_3_9)]
394+
#[test]
395+
fn test_generic_both_subscriptions_types() {
396+
use std::convert::Infallible;
397+
398+
Python::with_gil(|py| {
399+
let l = Bound::new(
400+
py,
401+
GenericList {
402+
items: vec![1, 2, 3]
403+
.iter()
404+
.map(|x| -> PyObject {
405+
let x: Result<Bound<'_, PyInt>, Infallible> = x.into_pyobject(py);
406+
return x.unwrap().into_any().unbind();
407+
})
408+
.collect(),
409+
},
410+
)
411+
.unwrap();
412+
let ty = py.get_type::<GenericList>();
413+
py_assert!(py, l, "l[0] == 1");
414+
py_run!(
415+
py,
416+
ty,
417+
"import types;
418+
import typing;
419+
IntOrNone: typing.Alias = typing.Union[int, None];
420+
assert ty[IntOrNone] == types.GenericAlias(ty, (IntOrNone,))"
421+
);
422+
py_assert!(py, l, "list(reversed(l)) == [3, 2, 1]");
423+
});
424+
}

0 commit comments

Comments
 (0)