Skip to content

Commit

Permalink
Merge pull request #52 from Lucs1590/set-key
Browse files Browse the repository at this point in the history
Set OpenAI Key
  • Loading branch information
Lucs1590 authored Dec 14, 2024
2 parents bae22bd + bd28ed4 commit 91791e5
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .cz.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "cz_conventional_commits"
tag_format = "$version"
version_scheme = "semver"
version = "1.0.0"
version = "1.1.0"
update_changelog_on_bump = true
major_version_zero = true
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name="strava-to-trainingpeaks",
version="1.0.0",
version="1.1.0",
author="Lucas de Brito Silva",
author_email="[email protected]",
description="A tool to sync Strava activities with TrainingPeaks, with the OpenAI API creating the workout descriptions.",
Expand Down
15 changes: 14 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def main():
logger.info("Validating the TCX file")
_, tcx_data = validate_tcx_file(file_path)
if ask_llm_analysis():
check_openai_key()
plan = ask_training_plan()
language = ask_desired_language()
logger.info("Performing LLM analysis")
Expand Down Expand Up @@ -94,7 +95,8 @@ def ask_activity_id() -> str:


def download_tcx_file(activity_id: str, sport: str) -> None:
url = f"https://www.strava.com/activities/{activity_id}/export_{'original' if sport in ['Swim', 'Other'] else 'tcx'}"
url = f"https://www.strava.com/activities/{activity_id}/export_{
'original' if sport in ['Swim', 'Other'] else 'tcx'}"
try:
webbrowser.open(url)
except Exception as err:
Expand Down Expand Up @@ -202,6 +204,17 @@ def ask_desired_language() -> str:
).ask()


def check_openai_key() -> None:
if not os.getenv("OPENAI_API_KEY"):
openai_key = questionary.password(
"Enter your OpenAI API key:"
).ask()
with open(".env", "w", encoding="utf-8") as env_file:
env_file.write(f"OPENAI_API_KEY={openai_key}")
load_dotenv()
logger.info("OpenAI API key loaded successfully.")


def perform_llm_analysis(data: TCXReader, sport: str, plan: str, language: str) -> str:
dataframe = preprocess_trackpoints_data(data)

Expand Down
34 changes: 30 additions & 4 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
perform_llm_analysis,
preprocess_trackpoints_data,
run_euclidean_dist_deletion,
remove_null_columns
remove_null_columns,
check_openai_key
)


Expand Down Expand Up @@ -146,6 +147,7 @@ def test_indent_xml_file_error(self):

self.assertTrue(mock_parse_string.called)

@patch('src.main.check_openai_key')
@patch('src.main.get_latest_download')
@patch('src.main.ask_sport')
@patch('src.main.ask_file_location')
Expand All @@ -155,11 +157,12 @@ def test_indent_xml_file_error(self):
@patch('src.main.validate_tcx_file')
@patch('src.main.indent_xml_file')
def test_main(self, mock_indent, mock_validate, mock_format, mock_download, mock_ask_id,
mock_ask_location, mock_ask_sport, mock_latest_download):
mock_ask_location, mock_ask_sport, mock_latest_download, mock_openai_key):
mock_ask_sport.return_value = "Swim"
mock_ask_location.return_value = "Download"
mock_ask_id.return_value = "12345"
mock_latest_download.return_value = "assets/swim.tcx"
mock_openai_key.return_value = None

main()

Expand All @@ -172,6 +175,7 @@ def test_main(self, mock_indent, mock_validate, mock_format, mock_download, mock
mock_validate.assert_not_called()
mock_indent.assert_called_once_with("assets/swim.tcx")

@patch('src.main.check_openai_key')
@patch('src.main.ask_sport')
@patch('src.main.ask_file_location')
@patch('src.main.ask_activity_id')
Expand All @@ -181,7 +185,8 @@ def test_main(self, mock_indent, mock_validate, mock_format, mock_download, mock
@patch('src.main.validate_tcx_file')
@patch('src.main.indent_xml_file')
def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_latest_download, mock_download,
mock_ask_id, mock_ask_location, mock_ask_sport):
mock_ask_id, mock_ask_location, mock_ask_sport, mock_openai_key):
mock_openai_key.return_value = None
mock_ask_sport.return_value = "InvalidSport"
mock_ask_location.return_value = "Download"
mock_ask_id.return_value = "12345"
Expand All @@ -199,6 +204,7 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
mock_validate.assert_not_called()
mock_indent.assert_not_called()

@patch('src.main.check_openai_key')
@patch('src.main.ask_desired_language')
@patch('src.main.ask_training_plan')
@patch('src.main.perform_llm_analysis')
Expand All @@ -213,7 +219,7 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
@patch('src.main.indent_xml_file')
def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask_path, mock_download,
mock_ask_id, mock_ask_location, mock_ask_sport, mock_llm_analysis, mock_perform_llm,
mock_training_plan, mock_language):
mock_training_plan, mock_language, mock_openai_key):
mock_ask_sport.return_value = "Bike"
mock_ask_location.return_value = "Local"
mock_ask_path.return_value = "assets/bike.tcx"
Expand All @@ -222,6 +228,7 @@ def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask
mock_perform_llm.return_value = "Training Plan"
mock_training_plan.return_value = ""
mock_language.return_value = "Portuguese"
mock_openai_key.return_value = None

main()

Expand Down Expand Up @@ -329,6 +336,25 @@ def test_ask_llm_analysis(self):
)
self.assertTrue(result)

def test_check_openai_api_key(self):
with patch('src.main.os.getenv') as mock_getenv:
mock_getenv.return_value = "API_KEY"
check_openai_key()
self.assertTrue(mock_getenv.called)
self.assertEqual(os.getenv("OPENAI_API_KEY"), "API_KEY")

@patch('src.main.os.getenv')
@patch('src.main.questionary.password')
@patch('builtins.open', new_callable=unittest.mock.mock_open)
def test_check_openai_api_key_empty(self, mock_open, mock_text, mock_getenv):
mock_text.return_value.ask.return_value = "API_KEY"
mock_getenv.return_value = None
mock_open.return_value.write.return_value = "API_KEY"
check_openai_key()
mock_text.assert_called_once_with('Enter your OpenAI API key:')
self.assertTrue(mock_open.called)
self.assertTrue(mock_text.called)

@patch('src.main.ChatOpenAI')
def test_perform_llm_analysis(self, mock_chat):
mock_invoke = mock_chat.return_value.invoke.return_value
Expand Down

0 comments on commit 91791e5

Please sign in to comment.