Skip to content

Commit 9807384

Browse files
committed
Changed reproducibility info to use path to Python file that is run
1 parent b96be4e commit 9807384

File tree

4 files changed

+105
-88
lines changed

4 files changed

+105
-88
lines changed

README.md

+2-2
Original file line number | Diff line number | Diff line change
@@ -373,12 +373,12 @@ Specifically, Tap has a method called `get_reproducibility_info` that returns a
373373
- The time when the command was run
374374
- Ex. `Thu Aug 15 00:09:13 2019`
375375
- Git root
376-
- The root of the git repo containing the code
376+
- The root of the git repo containing the code that was run
377377
- Ex. `/Users/swansonk14/typed-argument-parser`
378378
- Git url
379379
- The url to the git repo, specifically pointing to the current git hash (i.e. the hash of HEAD in the local repo)
380380
- Ex. [https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998](https://github.com/swansonk14/typed-argument-parser/tree/446cf046631d6bdf7cab6daec93bf7a02ac00998)
381-
- Uncommited changes
381+
- Uncommitted changes
382382
- Whether there are any uncommitted changes in the git repo (i.e. whether the code is different from the code at the above git hash)
383383
- Ex. `True` or `False`
384384

tap/tap.py

+30-14
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from copy import deepcopy
44
from functools import partial
55
import json
6+
from pathlib import Path
67
from pprint import pformat
78
from shlex import quote
89
import sys
@@ -15,12 +16,9 @@
1516
from tap.utils import (
1617
get_class_variables,
1718
get_argument_name,
18-
get_git_root,
1919
get_dest,
20-
get_git_url,
2120
get_origin,
22-
has_git,
23-
has_uncommitted_changes,
21+
GitInfo,
2422
is_option_arg,
2523
type_to_str,
2624
get_literals,
@@ -337,7 +335,7 @@ def configure(self) -> None:
337335
pass
338336

339337
@staticmethod
340-
def get_reproducibility_info() -> Dict[str, str]:
338+
def get_reproducibility_info(repo_path: Optional[str] = None) -> Dict[str, str]:
341339
"""Gets a dictionary of reproducibility information.
342340
343341
Reproducibility information always includes:
@@ -350,27 +348,37 @@ def get_reproducibility_info() -> Dict[str, str]:
350348
Ex. https://github.com/swansonk14/rationale-alignment/tree/<hash>
351349
- git_has_uncommitted_changes: Whether the current git repo has uncommitted changes.
352350
351+
:param repo_path: Path to the git repo to examine for reproducibility info.
352+
If None, uses the git repo of the Python file that is run.
353353
:return: A dictionary of reproducibility information.
354354
"""
355+
# Get the path to the Python file that is being run
356+
if repo_path is None:
357+
repo_path = (Path.cwd() / Path(sys.argv[0]).parent).resolve()
358+
355359
reproducibility = {
356360
'command_line': f'python {" ".join(quote(arg) for arg in sys.argv)}',
357361
'time': time.strftime('%c')
358362
}
359363

360-
if has_git():
361-
reproducibility['git_root'] = get_git_root()
362-
reproducibility['git_url'] = get_git_url(commit_hash=True)
363-
reproducibility['git_has_uncommitted_changes'] = has_uncommitted_changes()
364+
git_info = GitInfo(repo_path=repo_path)
365+
366+
if git_info.has_git():
367+
reproducibility['git_root'] = git_info.get_git_root()
368+
reproducibility['git_url'] = git_info.get_git_url(commit_hash=True)
369+
reproducibility['git_has_uncommitted_changes'] = git_info.has_uncommitted_changes()
364370

365371
return reproducibility
366372

367-
def _log_all(self) -> Dict[str, Any]:
373+
def _log_all(self, repo_path: Optional[str] = None) -> Dict[str, Any]:
368374
"""Gets all arguments along with reproducibility information.
369375
376+
:param repo_path: Path to the git repo to examine for reproducibility info.
377+
If None, uses the git repo of the Python file that is run.
370378
:return: A dictionary containing all arguments along with reproducibility information.
371379
"""
372380
arg_log = self.as_dict()
373-
arg_log['reproducibility'] = self.get_reproducibility_info()
381+
arg_log['reproducibility'] = self.get_reproducibility_info(repo_path=repo_path)
374382

375383
return arg_log
376384

@@ -590,12 +598,17 @@ def from_dict(self, args_dict: Dict[str, Any], skip_unsettable: bool = False) ->
590598

591599
return self
592600

593-
def save(self, path: str, with_reproducibility: bool = True, skip_unpicklable: bool = False) -> None:
601+
def save(self,
602+
path: str, with_reproducibility: bool = True,
603+
skip_unpicklable: bool = False,
604+
repo_path: Optional[str] = None) -> None:
594605
"""Saves the arguments and reproducibility information in JSON format, pickling what can't be encoded.
595606
596607
:param path: Path to the JSON file where the arguments will be saved.
597608
:param with_reproducibility: If True, adds a "reproducibility" field with information (e.g. git hash)
598609
to the JSON file.
610+
:param repo_path: Path to the git repo to examine for reproducibility info.
611+
If None, uses the git repo of the Python file that is run.
599612
:param skip_unpicklable: If True, does not save attributes whose values cannot be pickled.
600613
"""
601614
with open(path, 'w') as f:
@@ -605,14 +618,17 @@ def save(self, path: str, with_reproducibility: bool = True, skip_unpicklable: b
605618
def load(self,
606619
path: str,
607620
check_reproducibility: bool = False,
608-
skip_unsettable: bool = False) -> TapType:
621+
skip_unsettable: bool = False,
622+
repo_path: Optional[str] = None) -> TapType:
609623
"""Loads the arguments in JSON format. Note: Due to JSON, tuples are loaded as lists.
610624
611625
:param path: Path to the JSON file where the arguments will be loaded from.
612626
:param check_reproducibility: When True, raises an error if the loaded reproducibility
613627
information doesn't match the current reproducibility information.
614628
:param skip_unsettable: When True, skips attributes that cannot be set in the Tap object,
615629
e.g. properties without setters.
630+
:param repo_path: Path to the git repo to examine for reproducibility info.
631+
If None, uses the git repo of the Python file that is run.
616632
:return: Returns self.
617633
"""
618634
with open(path) as f:
@@ -621,7 +637,7 @@ def load(self,
621637
# Remove loaded reproducibility information since it is no longer valid
622638
saved_reproducibility_data = args_dict.pop('reproducibility', None)
623639
if check_reproducibility:
624-
current_reproducibility_data = self.get_reproducibility_info()
640+
current_reproducibility_data = self.get_reproducibility_info(repo_path=repo_path)
625641
enforce_reproducibility(saved_reproducibility_data, current_reproducibility_data, path)
626642

627643
self.from_dict(args_dict, skip_unsettable=skip_unsettable)

tap/utils.py

+55-53
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
PRIMITIVES = (str, int, float, bool)
3232

3333

34-
def check_output(command: List[str], suppress_stderr: bool = True) -> str:
34+
def check_output(command: List[str], suppress_stderr: bool = True, **kwargs) -> str:
3535
"""Runs subprocess.check_output and returns the result as a string.
3636
3737
:param command: A list of strings representing the command to run on the command line.
@@ -40,77 +40,79 @@ def check_output(command: List[str], suppress_stderr: bool = True) -> str:
4040
"""
4141
with open(os.devnull, 'w') as devnull:
4242
devnull = devnull if suppress_stderr else None
43-
output = subprocess.check_output(command, stderr=devnull).decode('utf-8').strip()
43+
output = subprocess.check_output(command, stderr=devnull, **kwargs).decode('utf-8').strip()
4444
return output
4545

4646

47-
def has_git() -> bool:
48-
"""Returns whether git is installed.
47+
class GitInfo:
48+
"""Class with helper methods for extracting information about a git repo."""
4949

50-
:return: True if git is installed, False otherwise.
51-
"""
52-
try:
53-
output = check_output(['git', 'rev-parse', '--is-inside-work-tree'])
54-
return output == 'true'
55-
except (FileNotFoundError, subprocess.CalledProcessError):
56-
return False
50+
def __init__(self, repo_path: str):
51+
self.repo_path = repo_path
5752

53+
def has_git(self) -> bool:
54+
"""Returns whether git is installed.
5855
59-
def get_git_root() -> str:
60-
"""Gets the root directory of the git repo where the command is run.
61-
62-
:return: The root directory of the current git repo.
63-
"""
64-
return check_output(['git', 'rev-parse', '--show-toplevel'])
65-
56+
:return: True if git is installed, False otherwise.
57+
"""
58+
try:
59+
output = check_output(['git', 'rev-parse', '--is-inside-work-tree'], cwd=self.repo_path)
60+
return output == 'true'
61+
except (FileNotFoundError, subprocess.CalledProcessError):
62+
return False
6663

67-
def get_git_url(commit_hash: bool = True) -> str:
68-
"""Gets the https url of the git repo where the command is run.
64+
def get_git_root(self) -> str:
65+
"""Gets the root directory of the git repo where the command is run.
6966
70-
:param commit_hash: If True, the url links to the latest local git commit hash.
71-
If False, the url links to the general git url.
72-
:return: The https url of the current git repo.
73-
"""
74-
# Get git url (either https or ssh)
75-
try:
76-
url = check_output(['git', 'remote', 'get-url', 'origin'])
77-
except subprocess.CalledProcessError:
78-
# For git versions <2.0
79-
url = check_output(['git', 'config', '--get', 'remote.origin.url'])
67+
:return: The root directory of the current git repo.
68+
"""
69+
return check_output(['git', 'rev-parse', '--show-toplevel'], cwd=self.repo_path)
8070

81-
# Remove .git at end
82-
url = url[:-len('.git')]
71+
def get_git_url(self, commit_hash: bool = True) -> str:
72+
"""Gets the https url of the git repo where the command is run.
8373
84-
# Convert ssh url to https url
85-
m = re.search('git@(.+):', url)
86-
if m is not None:
87-
domain = m.group(1)
88-
path = url[m.span()[1]:]
89-
url = f'https://{domain}/{path}'
74+
:param commit_hash: If True, the url links to the latest local git commit hash.
75+
If False, the url links to the general git url.
76+
:return: The https url of the current git repo.
77+
"""
78+
# Get git url (either https or ssh)
79+
try:
80+
url = check_output(['git', 'remote', 'get-url', 'origin'], cwd=self.repo_path)
81+
except subprocess.CalledProcessError:
82+
# For git versions <2.0
83+
url = check_output(['git', 'config', '--get', 'remote.origin.url'], cwd=self.repo_path)
9084

91-
if commit_hash:
92-
# Add tree and hash of current commit
93-
url = f'{url}/tree/{get_git_hash()}'
85+
# Remove .git at end
86+
url = url[:-len('.git')]
9487

95-
return url
88+
# Convert ssh url to https url
89+
m = re.search('git@(.+):', url)
90+
if m is not None:
91+
domain = m.group(1)
92+
path = url[m.span()[1]:]
93+
url = f'https://{domain}/{path}'
9694

95+
if commit_hash:
96+
# Add tree and hash of current commit
97+
url = f'{url}/tree/{self.get_git_hash()}'
9798

98-
def get_git_hash() -> str:
99-
"""Gets the git hash of HEAD of the git repo where the command is run.
99+
return url
100100

101-
:return: The git hash of HEAD of the current git repo.
102-
"""
103-
return check_output(['git', 'rev-parse', 'HEAD'])
101+
def get_git_hash(self) -> str:
102+
"""Gets the git hash of HEAD of the git repo where the command is run.
104103
104+
:return: The git hash of HEAD of the current git repo.
105+
"""
106+
return check_output(['git', 'rev-parse', 'HEAD'], cwd=self.repo_path)
105107

106-
def has_uncommitted_changes() -> bool:
107-
"""Returns whether there are uncommitted changes in the git repo where the command is run.
108+
def has_uncommitted_changes(self) -> bool:
109+
"""Returns whether there are uncommitted changes in the git repo where the command is run.
108110
109-
:return: True if there are uncommitted changes in the current git repo, False otherwise.
110-
"""
111-
status = check_output(['git', 'status'])
111+
:return: True if there are uncommitted changes in the current git repo, False otherwise.
112+
"""
113+
status = check_output(['git', 'status'], cwd=self.repo_path)
112114

113-
return not status.endswith(NO_CHANGES_STATUS)
115+
return not status.endswith(NO_CHANGES_STATUS)
114116

115117

116118
def type_to_str(type_annotation: Union[type, Any]) -> str:

tests/test_utils.py

+18-19
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from collections import OrderedDict
22
import json
33
import os
4-
import platform
54
import subprocess
65
from tempfile import TemporaryDirectory
76
from typing import Any, Callable, List, Dict, Set, Tuple, Union
@@ -10,12 +9,9 @@
109
from typing_extensions import Literal
1110

1211
from tap.utils import (
13-
has_git,
1412
get_class_column,
1513
get_class_variables,
16-
get_git_root,
17-
get_git_url,
18-
has_uncommitted_changes,
14+
GitInfo,
1915
type_to_str,
2016
get_literals,
2117
TupleTypeEnforcer,
@@ -37,6 +33,7 @@ def setUp(self) -> None:
3733
subprocess.check_output(['touch', 'README.md'])
3834
subprocess.check_output(['git', 'add', 'README.md'])
3935
subprocess.check_output(['git', 'commit', '-m', 'Initial commit'])
36+
self.git_info = GitInfo(repo_path=self.temp_dir.name)
4037

4138
def tearDown(self) -> None:
4239
os.chdir(self.prev_dir)
@@ -49,72 +46,74 @@ def tearDown(self) -> None:
4946
self.temp_dir.cleanup()
5047

5148
def test_has_git_true(self) -> None:
52-
self.assertTrue(has_git())
49+
self.assertTrue(self.git_info.has_git())
5350

5451
def test_has_git_false(self) -> None:
5552
with TemporaryDirectory() as temp_dir_no_git:
5653
os.chdir(temp_dir_no_git)
57-
self.assertFalse(has_git())
54+
self.git_info.repo_path = temp_dir_no_git
55+
self.assertFalse(self.git_info.has_git())
56+
self.git_info.repo_path = self.temp_dir.name
5857
os.chdir(self.temp_dir.name)
5958

6059
def test_get_git_root(self) -> None:
6160
# Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private
62-
self.assertTrue(get_git_root().endswith(self.temp_dir.name.replace('\\', '/')))
61+
self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace('\\', '/')))
6362

6463
def test_get_git_root_subdir(self) -> None:
6564
subdir = os.path.join(self.temp_dir.name, 'subdir')
6665
os.makedirs(subdir)
6766
os.chdir(subdir)
6867

6968
# Ideally should be self.temp_dir.name == get_git_root() but the OS may add a prefix like /private
70-
self.assertTrue(get_git_root().endswith(self.temp_dir.name.replace('\\', '/')))
69+
self.assertTrue(self.git_info.get_git_root().endswith(self.temp_dir.name.replace('\\', '/')))
7170

7271
os.chdir(self.temp_dir.name)
7372

7473
def test_get_git_url_https(self) -> None:
75-
self.assertEqual(get_git_url(commit_hash=False), self.url)
74+
self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url)
7675

7776
def test_get_git_url_https_hash(self) -> None:
7877
url = f'{self.url}/tree/'
79-
self.assertEqual(get_git_url(commit_hash=True)[:len(url)], url)
78+
self.assertEqual(self.git_info.get_git_url(commit_hash=True)[:len(url)], url)
8079

8180
def test_get_git_url_ssh(self) -> None:
8281
subprocess.run(['git', 'remote', 'set-url', 'origin', 'git@github.com:test_account/test_repo.git'])
83-
self.assertEqual(get_git_url(commit_hash=False), self.url)
82+
self.assertEqual(self.git_info.get_git_url(commit_hash=False), self.url)
8483

8584
def test_get_git_url_ssh_hash(self) -> None:
8685
subprocess.run(['git', 'remote', 'set-url', 'origin', 'git@github.com:test_account/test_repo.git'])
8786
url = f'{self.url}/tree/'
88-
self.assertEqual(get_git_url(commit_hash=True)[:len(url)], url)
87+
self.assertEqual(self.git_info.get_git_url(commit_hash=True)[:len(url)], url)
8988

9089
def test_get_git_url_https_enterprise(self) -> None:
9190
true_url = 'https://github.tap.com/test_account/test_repo'
9291
subprocess.run(['git', 'remote', 'set-url', 'origin', f'{true_url}.git'])
93-
self.assertEqual(get_git_url(commit_hash=False), true_url)
92+
self.assertEqual(self.git_info.get_git_url(commit_hash=False), true_url)
9493

9594
def test_get_git_url_https_hash_enterprise(self) -> None:
9695
true_url = 'https://github.tap.com/test_account/test_repo'
9796
subprocess.run(['git', 'remote', 'set-url', 'origin', f'{true_url}.git'])
9897
url = f'{true_url}/tree/'
99-
self.assertEqual(get_git_url(commit_hash=True)[:len(url)], url)
98+
self.assertEqual(self.git_info.get_git_url(commit_hash=True)[:len(url)], url)
10099

101100
def test_get_git_url_ssh_enterprise(self) -> None:
102101
true_url = 'https://github.tap.com/test_account/test_repo'
103102
subprocess.run(['git', 'remote', 'set-url', 'origin', 'git@github.tap.com:test_account/test_repo.git'])
104-
self.assertEqual(get_git_url(commit_hash=False), true_url)
103+
self.assertEqual(self.git_info.get_git_url(commit_hash=False), true_url)
105104

106105
def test_get_git_url_ssh_hash_enterprise(self) -> None:
107106
true_url = 'https://github.tap.com/test_account/test_repo'
108107
subprocess.run(['git', 'remote', 'set-url', 'origin', 'git@github.tap.com:test_account/test_repo.git'])
109108
url = f'{true_url}/tree/'
110-
self.assertEqual(get_git_url(commit_hash=True)[:len(url)], url)
109+
self.assertEqual(self.git_info.get_git_url(commit_hash=True)[:len(url)], url)
111110

112111
def test_has_uncommitted_changes_false(self) -> None:
113-
self.assertFalse(has_uncommitted_changes())
112+
self.assertFalse(self.git_info.has_uncommitted_changes())
114113

115114
def test_has_uncommited_changes_true(self) -> None:
116115
subprocess.run(['touch', 'main.py'])
117-
self.assertTrue(has_uncommitted_changes())
116+
self.assertTrue(self.git_info.has_uncommitted_changes())
118117

119118

120119
class TypeToStrTests(TestCase):

0 commit comments

Comments
 (0)