Skip to content

Commit

Permalink
[mypy] nncf/common/logging (#3181)
Browse files Browse the repository at this point in the history
### Changes

Enable mypy for `nncf/common/logging`
Add method `update` to `track` to use `pbar.update(advance=1)` instead
of `pbar.progress.update(pbar.task, advance=1)`
  • Loading branch information
AlexanderDokuchaev authored Jan 15, 2025
1 parent 092da80 commit 5726fcb
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 33 deletions.
26 changes: 19 additions & 7 deletions nncf/common/logging/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import Logger
from typing import Generic, Iterable, Iterator, Optional, TypeVar

from nncf.common.logging import nncf_logger

TObj = TypeVar("TObj")


class ProgressBar:
class ProgressBar(Generic[TObj]):
"""
A basic progress bar specifically for the logging.
It does not print at the same line, instead it logs multiple lines intentionally to avoid multiprocessing issues.
Expand All @@ -24,13 +29,20 @@ class ProgressBar:
:param total: the expected total number of iterations
"""

def __init__(self, iterable, logger=nncf_logger, desc="", num_lines=10, total=None):
def __init__(
self,
iterable: Iterable[TObj],
logger: Logger = nncf_logger,
desc: str = "",
num_lines: int = 10,
total: Optional[int] = None,
):
self._logger = logger
self._iterable = iterable
self._desc = desc
self._num_lines = num_lines

self._index = 0
self._index: int = 0
self._width = 16
self._is_enabled = False
self._total = None
Expand All @@ -45,7 +57,7 @@ def __init__(self, iterable, logger=nncf_logger, desc="", num_lines=10, total=No

if iterable is not None and self._total is None:
try:
self._total = len(iterable)
self._total = len(iterable) # type: ignore[arg-type]
except (TypeError, AttributeError):
logger.error(
"Progress bar is disabled because the given iterable is invalid: "
Expand All @@ -63,15 +75,15 @@ def __init__(self, iterable, logger=nncf_logger, desc="", num_lines=10, total=No
self._step = max(1, self._total // (self._num_lines - 1))
self._is_enabled = True

def __iter__(self):
def __iter__(self) -> Iterator[TObj]:
for obj in self._iterable:
yield obj
if self._is_enabled:
self._print_next()

def _print_next(self):
def _print_next(self) -> None:
self._index += 1
if self._index > self._total:
if self._total is None or self._index > self._total:
return

if self._index % self._step == 0 or self._index == self._total:
Expand Down
60 changes: 41 additions & 19 deletions nncf/common/logging/track_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Callable, Iterable, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Union

from rich.console import Console
from rich.progress import BarColumn
from rich.progress import Column
from rich.progress import Progress
from rich.progress import ProgressColumn
from rich.progress import ProgressType
Expand All @@ -24,6 +24,7 @@
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn
from rich.style import StyleType
from rich.table import Column
from rich.text import Text

INTEL_BLUE_COLOR = "#0068b5"
Expand All @@ -49,13 +50,13 @@ def render(self, task: Task) -> Text:


class TimeElapsedColumnWithStyle(TimeElapsedColumn):
def render(self, task: "Task") -> Text:
def render(self, task: Task) -> Text:
text = super().render(task)
return Text(text._text[0], style=INTEL_BLUE_COLOR)


class TimeRemainingColumnWithStyle(TimeRemainingColumn):
def render(self, task: "Task") -> Text:
def render(self, task: Task) -> Text:
text = super().render(task)
return Text(text._text[0], style=INTEL_BLUE_COLOR)

Expand All @@ -65,7 +66,7 @@ class WeightedProgress(Progress):
A class to perform a weighted progress tracking.
"""

def update(self, task_id: TaskID, **kwargs) -> None:
def update(self, task_id: TaskID, **kwargs: Any) -> None:
task = self._tasks[task_id]

advance = kwargs.get("advance", None)
Expand All @@ -84,7 +85,7 @@ def advance(self, task_id: TaskID, advance: float = 1) -> None:
advance = self.weighted_advance(task, advance)
super().advance(task_id, advance)

def reset(self, task_id: TaskID, **kwargs) -> None:
def reset(self, task_id: TaskID, **kwargs: Any) -> None:
task = self._tasks[task_id]

completed = kwargs.get("completed", None)
Expand All @@ -104,8 +105,8 @@ def weighted_advance(task: Task, advance: float) -> float:
if advance % 1 != 0:
raise Exception(f"Unexpected `advance` value: {advance}.")
advance = int(advance)
current_step = task.fields["completed_steps"]
weighted_advance = sum(task.fields["weights"][current_step : current_step + advance])
current_step: int = task.fields["completed_steps"]
weighted_advance: float = sum(task.fields["weights"][current_step : current_step + advance])
task.fields["completed_steps"] = current_step + advance
return weighted_advance

Expand All @@ -116,13 +117,13 @@ def get_weighted_completed(task: Task, completed: float) -> float:
"""
if completed % 1 != 0:
raise Exception(f"Unexpected `completed` value: {completed}.")
return sum(task.fields["weights"][: int(completed)])
return float(sum(task.fields["weights"][: int(completed)]))


class track:
class track(Generic[ProgressType]):
def __init__(
self,
sequence: Optional[Union[Sequence[ProgressType], Iterable[ProgressType]]] = None,
sequence: Union[Sequence[ProgressType], Iterable[ProgressType], None] = None,
description: str = "Working...",
total: Optional[float] = None,
auto_refresh: bool = True,
Expand All @@ -144,6 +145,19 @@ def __init__(
This function is very similar to rich.progress.track(), but with some customizations.
Usage:
```
arr = [1,2]
for i in track(arr, description="Processing..."):
print(i)
with track[None](total=len(arr), description="Processing...") as pbar:
for i in arr:
pbar.update(advance=1)
```
:param sequence: An iterable (must support "len") you wish to iterate over.
:param description: Description of the task to show next to the progress bar. Defaults to "Working".
:param total: Total number of steps. Default is len(sequence).
Expand All @@ -169,7 +183,7 @@ def __init__(
self.total = sum(self.weights) if self.weights is not None else total
self.description = description
self.update_period = update_period
self.task = None
self.task: Optional[TaskID] = None

self.columns: List[ProgressColumn] = (
[TextColumn("[progress.description]{task.description}")] if description else []
Expand Down Expand Up @@ -198,7 +212,7 @@ def __init__(
)
)

disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0)
disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0) # type: ignore[arg-type]

progress_cls = Progress if weights is None else WeightedProgress
self.progress = progress_cls(
Expand All @@ -211,7 +225,9 @@ def __init__(
disable=disable,
)

def __iter__(self) -> Iterable[ProgressType]:
def __iter__(self) -> Iterator[ProgressType]:
if self.sequence is None:
raise RuntimeError("__iter__ called without set sequence.")
with self:
yield from self.progress.track(
self.sequence,
Expand All @@ -221,16 +237,22 @@ def __iter__(self) -> Iterable[ProgressType]:
update_period=self.update_period,
)

def __enter__(self):
kwargs = {}
def __enter__(self) -> track[ProgressType]:
kwargs: Dict[str, Any] = {}
if self.weights is not None:
kwargs["weights"] = self.weights
kwargs["completed_steps"] = 0
self.task = self.progress.add_task(self.description, total=self.total, **kwargs)
self.progress.__enter__()
return self

def __exit__(self, *args):
def __exit__(self, *args: Any) -> None:
self.progress.__exit__(*args)
self.progress.remove_task(self.task)
self.task = None
if self.task is not None:
self.progress.remove_task(self.task)
self.task = None

def update(self, advance: float, **kwargs: Any) -> None:
if self.task is None:
raise RuntimeError("update is available only inside context manager.")
self.progress.update(self.task, advance=advance, **kwargs)
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
engine = factory.EngineFactory.create(model_with_outputs)
iterations_number = self._get_iterations_number()
processed_samples = 0
for input_data in track( # type: ignore
for input_data in track(
islice(self.dataset.get_inference_data(), iterations_number),
total=iterations_number,
description="Statistics collection",
Expand Down
4 changes: 2 additions & 2 deletions nncf/data/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def generate_text_data(
step_num = max(1, vocab_size // dataset_size)
ids_counter = 0

with track(total=dataset_size, description="Generating text data") as pbar:
with track[None](total=dataset_size, description="Generating text data") as pbar:
while len(generated_data) < dataset_size:
# Creating the input for pre-generate step
input_ids = torch.tensor([[ids_counter % vocab_size]]).to(model.device)
Expand All @@ -97,7 +97,7 @@ def generate_text_data(

ids_counter += step_num

pbar.progress.update(pbar.task, advance=1)
pbar.update(advance=1)
generated_data.extend(gen_text)

return generated_data
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ files = [
exclude = [
"nncf/common/composite_compression.py",
"nncf/common/compression.py",
"nncf/common/logging/progress_bar.py",
"nncf/common/logging/track_progress.py",
"nncf/common/pruning/clusterization.py",
"nncf/common/pruning/mask_propagation.py",
"nncf/common/pruning/model_analysis.py",
Expand Down
2 changes: 1 addition & 1 deletion tests/common/utils/test_progress_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ def test_track_context_manager(n, is_weighted):
with track(total=n, description="Progress...", weights=weights if is_weighted else None) as pbar:
for i in range(n):
assert pbar.progress._tasks[pbar.task].completed == (sum(weights[:i]) if is_weighted else i)
pbar.progress.update(pbar.task, advance=1)
pbar.update(advance=1)
2 changes: 1 addition & 1 deletion tests/post_training/pipelines/image_classification_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def process_result(request, userdata):
output_data = request.get_output_tensor().data
predicted_label = np.argmax(output_data, axis=1)
predictions[userdata] = predicted_label
pbar.progress.update(pbar.task, advance=1)
pbar.update(advance=1)

infer_queue.set_callback(process_result)

Expand Down

0 comments on commit 5726fcb

Please sign in to comment.