Skip to content

Commit 8022cf5

Browse files
committed
Fix FromPyObject for SkipSet.
Fixes #51 by allowing all string-iterables in FromPyObject implementation for SkipSet.
1 parent 6ff440b commit 8022cf5

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

src/embeddings.rs

+7-13
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use ndarray::Array2;
1616
use numpy::{IntoPyArray, NpyDataType, PyArray1, PyArray2, ToPyArray};
1717
use pyo3::class::iter::PyIterProtocol;
1818
use pyo3::prelude::*;
19-
use pyo3::types::{PyAny, PyList, PySet, PyTuple};
20-
use pyo3::{exceptions, PyMappingProtocol, PyTypeInfo};
19+
use pyo3::types::{PyAny, PyTuple};
20+
use pyo3::{exceptions, PyMappingProtocol};
2121
use toml::{self, Value};
2222

2323
use crate::{EmbeddingsWrap, PyEmbeddingIterator, PyVocab, PyWordSimilarity};
@@ -463,18 +463,12 @@ impl<'a> FromPyObject<'a> for Skips<'a> {
463463
fn extract(ob: &'a PyAny) -> Result<Self, PyErr> {
464464
let mut set = ob
465465
.len()
466-
.map(|len| HashSet::with_capacity(len))
466+
.map(HashSet::with_capacity)
467467
.unwrap_or_default();
468-
469-
let iter = if <PySet as PyTypeInfo>::is_instance(ob) {
470-
ob.iter().unwrap()
471-
} else if <PyList as PyTypeInfo>::is_instance(ob) {
472-
ob.iter().unwrap()
473-
} else {
474-
return Err(exceptions::TypeError::py_err("Iterable expected"));
475-
};
476-
477-
for el in iter {
468+
for el in ob
469+
.iter()
470+
.map_err(|_| exceptions::TypeError::py_err("Iterable expected"))?
471+
{
478472
set.insert(
479473
el?.extract()
480474
.map_err(|_| exceptions::TypeError::py_err("Expected String"))?,

0 commit comments

Comments
 (0)