diff --git a/datalad_tabby/io/load.py b/datalad_tabby/io/load.py index 332c0e4..681cacb 100644 --- a/datalad_tabby/io/load.py +++ b/datalad_tabby/io/load.py @@ -10,6 +10,8 @@ List, ) +from charset_normalizer import from_path as cs_from_path + from .load_utils import ( _assign_context, _compact_obj, @@ -28,6 +30,7 @@ def load_tabby( jsonld: bool = True, recursive: bool = True, cpaths: List | None = None, + encoding: str | None = None, ) -> Dict | List: """Load a tabby (TSV) record as structured (JSON(-LD)) data @@ -48,11 +51,14 @@ def load_tabby( With the ``jsonld`` flag, a declared or default JSON-LD context is loaded and inserted into the record. + + Encoding used when reading tsv files can be specified as ``encoding``. """ ldr = _TabbyLoader( jsonld=jsonld, recursive=recursive, cpaths=cpaths, + encoding=encoding, ) return ldr(src=src, single=single) @@ -63,6 +69,7 @@ def __init__( jsonld: bool = True, recursive: bool = True, cpaths: List[Path] | None = None, + encoding: str | None = None, ): std_convention_path = Path(__file__).parent / 'conventions' if cpaths is None: @@ -70,6 +77,7 @@ def __init__( else: cpaths.append(std_convention_path) self._cpaths = cpaths + self._encoding = encoding self._jsonld = jsonld self._recursive = recursive @@ -94,8 +102,24 @@ def _load_single( src=src, trace=trace, ) + if self._encoding is not None: + tsv_obj = self._parse_tsv_single(src, encoding=self._encoding) + else: + try: + tsv_obj = self._parse_tsv_single(src) + except UnicodeDecodeError: + # by default Path.open() uses locale.getencoding() + # that didn't work, try guessing + encoding = cs_from_path(src).best().encoding + tsv_obj = self._parse_tsv_single(src, encoding=encoding) + + obj.update(tsv_obj) + + return self._postproc_obj(obj, src=src, trace=trace) - with src.open(newline='') as tsvfile: + def _parse_tsv_single(self, src: Path, encoding: str | None = None) -> Dict: + obj = {} + with src.open(newline='', encoding=encoding) as tsvfile: reader = csv.reader(tsvfile, delimiter='\t') # row_id is useful for error reporting for row_id, row in enumerate(reader): @@ -117,8 +141,7 @@ def _load_single( # we support "sequence" values via multi-column values # supporting two ways just adds unnecessary complexity obj[key] = val - - return self._postproc_obj(obj, src=src, trace=trace) + return obj def _load_many( self, @@ -144,26 +167,56 @@ def _load_many( # the table field/column names have purposefully _nothing_ # to do with any possibly loaded JSON data - fieldnames = None + if self._encoding is not None: + tsv_array = self._parse_tsv_many( + src, obj_tmpl, trace=trace, fieldnames=None, encoding=self._encoding + ) + else: + try: + tsv_array = self._parse_tsv_many( + src, obj_tmpl, trace=trace, fieldnames=None + ) + except UnicodeDecodeError: + # by default Path.open() uses locale.getencoding() + # that didn't work, try guessing + encoding = cs_from_path(src).best().encoding + tsv_array = self._parse_tsv_many( + src, obj_tmpl, trace=trace, fieldnames=None, encoding=encoding + ) + + array.extend(tsv_array) + + return array - with src.open(newline='') as tsvfile: + def _parse_tsv_many( + self, + src: Path, + obj_tmpl: Dict, + trace: List, + fieldnames: List | None = None, + encoding: str | None = None, + ) -> List[Dict]: + array = [] + with src.open(newline="", encoding=encoding) as tsvfile: # we cannot use DictReader -- we need to support identically named # columns - reader = csv.reader(tsvfile, delimiter='\t') + reader = csv.reader(tsvfile, delimiter="\t") # row_id is useful for error reporting for row_id, row in enumerate(reader): # row is a list of field, with only as many items # as this particular row has columns - if not len(row) \ - or row[0].startswith('#') \ - or all(v is None for v in row): + if ( + not len(row) + or row[0].startswith("#") + or all(v is None for v in row) + ): # skip empty rows, rows with no key, or rows with # a comment key continue if fieldnames is None: # the first non-ignored row defines the property names/keys # cut `val` short and remove trailing empty items - fieldnames = row[:_get_index_after_last_nonempty(row)] + fieldnames = row[: _get_index_after_last_nonempty(row)] continue obj = obj_tmpl.copy() diff --git a/setup.cfg b/setup.cfg index 8b06c8c..fe2b49f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,7 @@ install_requires = datalad >= 0.18.0 datalad-next @ git+https://github.com/datalad/datalad-next.git@main datalad-metalad + charset-normalizer openpyxl pyld packages = find_namespace: