Skip to content

Commit

Permalink
feat(output): allow output file override (#4)
Browse files Browse the repository at this point in the history
* chore(ignore): add csv and parquet files to ignore

* chore(ci): add poetry check

* feat(output): allow output file override
  • Loading branch information
kiran94 authored May 5, 2023
1 parent 7ba796e commit c4ecf8c
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ jobs:
with:
poetry-version: ${{ matrix.poetry-version }}

- name: Poetry Check
run: poetry check

- name: Install Dependencies
run: |
make export_requirements
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,6 @@ terraform.rc

requirements.txt
requirements-dev.txt

*.csv
*.parquet
6 changes: 4 additions & 2 deletions prfiesta/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
@click.option('-u', '--users', required=True, multiple=True, help='The GitHub Users to search for. Can be multiple (space delimited)')
@click.option('-t', '--token', help='The Authentication token to use')
@click.option('-x', '--url', help='The URL of the Git provider to use')
@click.option('-o', '--output_type', type=click.Choice(['csv', 'parquet']), default='csv', help='The output format')
@click.option('-o', '--output', default=None, help='The output location')
@click.option('-ot', '--output_type', type=click.Choice(['csv', 'parquet']), default='csv', help='The output format')
@click.option('--after', type=click.DateTime(formats=['%Y-%m-%d']), help='Only search for pull requests after this date e.g 2023-01-01')
@click.option('--before', type=click.DateTime(formats=['%Y-%m-%d']), help='Only search for pull requests before this date e.g 2023-04-30')
def main(**kwargs) -> None:

users: tuple[str] = kwargs.get('users')
token: str = kwargs.get('token') or github_environment.get_token()
url: str = kwargs.get('url') or github_environment.get_url()
output: str = kwargs.get('output')
output_type: str = kwargs.get('output_type')
before: datetime = kwargs.get('before')
after: datetime = kwargs.get('after')
Expand All @@ -44,7 +46,7 @@ def main(**kwargs) -> None:
logger.info('Found [bold green]%s[/bold green] pull requests!', pr_frame.shape[0])

if not pr_frame.empty:
output_frame(pr_frame, output_type, spinner=spinner)
output_frame(pr_frame, output_type, spinner=spinner, output_name=output)

if __name__ == '__main__': # pragma: nocover
main()
19 changes: 10 additions & 9 deletions prfiesta/output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from datetime import datetime
from typing import Literal

Expand All @@ -10,19 +9,21 @@

logger = logging.getLogger(__name__)

output_directory = 'output'
OUTPUT_TYPE = Literal['csv', 'parquet']


def output_frame(frame: pd.DataFrame, output_type: OUTPUT_TYPE, spinner: Spinner, output_name: str = 'export', timestamp: datetime = None) -> None:
def output_frame(
frame: pd.DataFrame,
output_type: OUTPUT_TYPE,
spinner: Spinner,
output_name: str = None,
timestamp: datetime = None) -> None:

if not timestamp:
timestamp = datetime.now()
if not output_name:
if not timestamp:
timestamp = datetime.now()

os.makedirs(output_directory, exist_ok=True)

output_name = str(output_name) + '.' + timestamp.strftime('%Y-%m-%d_%H:%M:%S') + '.' + output_type
output_name = os.path.join(output_directory, output_name)
output_name = f"export.{timestamp.strftime('%Y-%m-%d_%H:%M:%S')}.{output_type}"

spinner.update(text=f'Writing export to {output_name}', style=SPINNER_STYLE)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@ def test_main(
assert mock_collector.return_value.collect.call_args_list == expected_collect_params

if not collect_response.empty:
assert mock_output_frame.call_args_list == [call(collect_response, expected_output_type, spinner=mock_spinner.return_value)]
assert mock_output_frame.call_args_list == [call(collect_response, expected_output_type, spinner=mock_spinner.return_value, output_name=None)]

assert result.exit_code == 0
16 changes: 5 additions & 11 deletions tests/test_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from datetime import datetime
from unittest.mock import ANY, Mock, call, patch
from unittest.mock import ANY, Mock, call

import pytest

Expand All @@ -13,36 +12,31 @@
('csv', None),
('parquet', None),
])
@patch('prfiesta.output.os')
def test_output_frame(mock_os: Mock, output_type: str, timestamp: datetime) -> None:
def test_output_frame(output_type: str, timestamp: datetime) -> None:

mock_frame: Mock = Mock()
mock_spinner: Mock = Mock()
mock_os.path.join = os.path.join

output_frame(mock_frame, output_type, mock_spinner, timestamp=timestamp)

assert [call('output', exist_ok=True)] == mock_os.makedirs.call_args_list
assert mock_spinner.update.called

if timestamp and output_type == 'csv':
assert [call('output/export.2021-01-01_00:00:00.csv', index=False)] == mock_frame.to_csv.call_args_list
assert [call('export.2021-01-01_00:00:00.csv', index=False)] == mock_frame.to_csv.call_args_list

elif timestamp and output_type == 'parquet':
assert [call('output/export.2021-01-01_00:00:00.parquet', index=False)] == mock_frame.to_parquet.call_args_list
assert [call('export.2021-01-01_00:00:00.parquet', index=False)] == mock_frame.to_parquet.call_args_list

elif not timestamp and output_type == 'csv':
assert [call(ANY, index=False)] == mock_frame.to_csv.call_args_list

elif not timestamp and output_type == 'parquet':
assert [call(ANY, index=False)] == mock_frame.to_parquet.call_args_list

@patch('prfiesta.output.os')
def test_output_frame_unknown_type(mock_os: Mock) -> None:
def test_output_frame_unknown_type() -> None:

mock_frame: Mock = Mock()
mock_spinner: Mock = Mock()
mock_os.path.join = os.path.join

with pytest.raises(ValueError, match='unknown output_type'):
output_frame(mock_frame, 'unknown_type', mock_spinner, timestamp=datetime(2021, 1, 1))

0 comments on commit c4ecf8c

Please sign in to comment.