Skip to content

Commit ab6c77d

Browse files
committed
Merge branch 'main' of github.com:krassowski/docstring-to-markdown
2 parents 28c94b2 + 51c56f7 commit ab6c77d

File tree

2 files changed

+112
-11
lines changed

2 files changed

+112
-11
lines changed

docstring_to_markdown/rst.py

+71-11
Original file line number | Diff line number | Diff line change
@@ -198,18 +198,26 @@ class State(IntEnum):
198198
PARSING_ROWS = auto()
199199
FINISHED = auto()
200200

201-
outer_border_pattern = r'^(\s*)=+( +=+)+$'
201+
outer_border_pattern: str
202+
column_top_prefix: str
203+
column_top_border: str
204+
column_end_offset: int
202205

203206
_state: int
204207
_column_starts: List[int]
208+
_columns_end: int
205209
_columns: List[str]
206210
_rows: List[List[str]]
207211
_max_sizes: List[int]
208212
_indent: str
209213

214+
def __init__(self):
    """Create the parser and put it into its initial state via `_reset_state`."""
    self._reset_state()
216+
210217
def _reset_state(self):
211218
self._state = TableParser.State.AWAITS
212219
self._column_starts = []
220+
self._columns_end = -1
213221
self._columns = []
214222
self._rows = []
215223
self._max_sizes = []
@@ -222,11 +230,13 @@ def initiate_parsing(self, line: str, current_language: str) -> IBlockBeginning:
222230
self._reset_state()
223231
match = re.match(self.outer_border_pattern, line)
224232
assert match
225-
self._indent = match.group(1) or ''
233+
groups = match.groupdict()
234+
self._indent = groups['indent'] or ''
226235
self._column_starts = []
227-
previous = ' '
236+
self._columns_end = match.end('column')
237+
previous = self.column_top_prefix
228238
for i, char in enumerate(line):
229-
if char == '=' and previous == ' ':
239+
if char == self.column_top_border and previous == self.column_top_prefix:
230240
self._column_starts.append(i)
231241
previous = char
232242
self._max_sizes = [0 for i in self._column_starts]
@@ -245,17 +255,24 @@ def consume(self, line: str) -> None:
245255
# TODO: check integrity?
246256
self._state += 1
247257
elif self._state == states.PARSING_ROWS:
248-
match = re.match(self.outer_border_pattern, line)
249-
if match:
250-
self._state += 1
251-
else:
252-
self._rows.append(self._split(line))
258+
self._consume_row(line)
259+
260+
def _consume_row(self, line: str):
    """Handle one line received while the parser is reading table rows.

    A line matching `outer_border_pattern` advances the parser state by
    one step; any other line is split into cells and stored as a row.
    """
    if re.match(self.outer_border_pattern, line) is None:
        # Not a border line: treat it as row content.
        self._rows.append(self._split(line))
        return
    # Border line encountered: move on to the next state.
    self._state += 1
253266

254267
def _split(self, line: str) -> List[str]:
255268
assert self._column_starts
256269
fragments = []
257270
for i, start in enumerate(self._column_starts):
258-
end = self._column_starts[i + 1] if i < len(self._column_starts) - 1 else None
271+
end = (
272+
self._column_starts[i + 1] + self.column_end_offset
273+
if i < len(self._column_starts) - 1 else
274+
self._columns_end
275+
)
259276
fragment = line[start:end].strip()
260277
self._max_sizes[i] = max(self._max_sizes[i], len(fragment))
261278
fragments.append(fragment)
@@ -281,6 +298,48 @@ def finish_consumption(self, final: bool) -> str:
281298
return result
282299

283300

301+
class SimpleTableParser(TableParser):
    """Table parser configured for tables bordered by runs of ``=``.

    The class only supplies the constants that the shared `TableParser`
    logic reads; it adds no behaviour of its own.
    """

    # Border line: optional indent, then runs of `=` separated by spaces.
    outer_border_pattern = r'^(?P<indent>\s*)=+(?P<column> +=+)+$'
    # A column starts where the border character follows the prefix character.
    column_top_prefix = ' '
    column_top_border = '='
    # Added to the next column's start to obtain the end of the current column.
    column_end_offset = 0
306+
307+
308+
class GridTableParser(TableParser):
    """Table parser configured for tables drawn with ``+``, ``-`` and ``|``.

    While rows are being parsed, lines are expected to alternate between
    content lines (starting with ``|``) and separator lines (starting
    with ``+-``); `_expecting_row_content` tracks which one is due next.
    """

    # Border line: optional indent, then one or more `+---` groups and a final `+`.
    outer_border_pattern = r'^(?P<indent>\s*)(?P<column>\+-+)+\+$'
    # A column starts where the border character follows the prefix character.
    column_top_prefix = '+'
    column_top_border = '-'
    # Added to the next column's start to obtain the end of the current column.
    column_end_offset = -1

    # True when the next row line should be content, False when a separator.
    _expecting_row_content: bool

    def _reset_state(self):
        """Reset the inherited state and expect a content line first."""
        super()._reset_state()
        self._expecting_row_content = True

    def _is_correct_row(self, line: str) -> bool:
        """Return whether `line` starts (after indentation) with the expected marker."""
        expected_prefix = '|' if self._expecting_row_content else '+-'
        return line.lstrip().startswith(expected_prefix)

    def can_consume(self, line: str) -> bool:
        """Return whether the parser accepts `line` in its current state.

        A finished parser accepts nothing; while parsing rows the line must
        be the expected kind of row; in every other state it is accepted.
        """
        current = self._state
        if current == TableParser.State.FINISHED:
            return False
        if current == TableParser.State.PARSING_ROWS:
            return self._is_correct_row(line)
        return True

    def _consume_row(self, line: str):
        """Store a content line, skip a separator line, or advance the state.

        An unexpected line advances the parser state by one step. An
        expected line is stored only when it is row content; either way
        the expectation flips for the following line.
        """
        if not self._is_correct_row(line):
            self._state += 1
            return
        if self._expecting_row_content:
            self._rows.append(self._split(line))
        self._expecting_row_content = not self._expecting_row_content
341+
342+
284343
class BlockParser(IParser):
285344
enclosure = '```'
286345
follower: Union['IParser', None] = None
@@ -445,7 +504,8 @@ def initiate_parsing(self, line: str, current_language: str) -> IBlockBeginning:
445504
MathBlockParser(),
446505
ExplicitCodeBlockParser(),
447506
DoubleColonBlockParser(),
448-
TableParser()
507+
SimpleTableParser(),
508+
GridTableParser()
449509
]
450510

451511
RST_SECTIONS = {

tests/test_rst.py

+41
Original file line number | Diff line number | Diff line change
@@ -531,6 +531,43 @@ def func(): pass
531531
If True, then sub-classes will be passed-through, otherwise
532532
"""
533533

534+
GRID_TABLE_IN_SKLEARN = """
535+
Attributes
536+
----------
537+
cv_results_ : dict of numpy (masked) ndarrays
538+
A dict with keys as column headers and values as columns, that can be
539+
imported into a pandas ``DataFrame``.
540+
For instance the below given table
541+
+------------+-----------+------------+-----------------+---+---------+
542+
|param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
543+
+============+===========+============+=================+===+=========+
544+
| 'poly' | -- | 2 | 0.80 |...| 2 |
545+
+------------+-----------+------------+-----------------+---+---------+
546+
| 'poly' | -- | 3 | 0.70 |...| 4 |
547+
+------------+-----------+------------+-----------------+---+---------+
548+
| 'rbf' | 0.1 | -- | 0.80 |...| 3 |
549+
+------------+-----------+------------+-----------------+---+---------+
550+
| 'rbf' | 0.2 | -- | 0.93 |...| 1 |
551+
+------------+-----------+------------+-----------------+---+---------+
552+
will be represented by a ``cv_results_`` dict
553+
"""
554+
555+
GRID_TABLE_IN_SKLEARN_MARKDOWN = """
556+
#### Attributes
557+
558+
- `cv_results_`: dict of numpy (masked) ndarrays
559+
A dict with keys as column headers and values as columns, that can be
560+
imported into a pandas ``DataFrame``.
561+
For instance the below given table
562+
| param_kernel | param_gamma | param_degree | split0_test_score | ... | rank_t... |
563+
| ------------ | ----------- | ------------ | ----------------- | --- | --------- |
564+
| 'poly' | -- | 2 | 0.80 | ... | 2 |
565+
| 'poly' | -- | 3 | 0.70 | ... | 4 |
566+
| 'rbf' | 0.1 | -- | 0.80 | ... | 3 |
567+
| 'rbf' | 0.2 | -- | 0.93 | ... | 1 |
568+
will be represented by a ``cv_results_`` dict
569+
"""
570+
534571
INTEGRATION = """
535572
Return a fixed frequency DatetimeIndex.
536573
@@ -661,6 +698,10 @@ def func(): pass
661698
'converts indented simple table': {
662699
'rst': SIMPLE_TABLE_IN_PARAMS,
663700
'md': SIMPLE_TABLE_IN_PARAMS_MARKDOWN
701+
},
702+
'converts indented grid table': {
703+
'rst': GRID_TABLE_IN_SKLEARN,
704+
'md': GRID_TABLE_IN_SKLEARN_MARKDOWN
664705
}
665706
}
666707

0 commit comments

Comments
 (0)