Skip to content

Commit 7a13a26

Browse files
committed
fix: Fix types around http calls. Properly accept both str and requests.Response, defensive programming around the json we get back from the API
1 parent 9bc2344 commit 7a13a26

File tree

4 files changed

+88
-31
lines changed

4 files changed

+88
-31
lines changed

deepl/api_data.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class Usage:
197197
"""
198198

199199
class Detail:
200-
def __init__(self, json: dict, prefix: str):
200+
def __init__(self, json: Optional[dict], prefix: str):
201201
self._count = util.get_int_safe(json, f"{prefix}_count")
202202
self._limit = util.get_int_safe(json, f"{prefix}_limit")
203203

@@ -238,7 +238,7 @@ def limit_exceeded(self) -> bool:
238238
def __str__(self) -> str:
239239
return f"{self.count} of {self.limit}" if self.valid else "Unknown"
240240

241-
def __init__(self, json: dict):
241+
def __init__(self, json: Optional[dict]):
242242
self._character = self.Detail(json, "character")
243243
self._document = self.Detail(json, "document")
244244
self._team_document = self.Detail(json, "team_document")

deepl/http_client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def request(
163163
request = self._prepare_request(
164164
method, url, data, json, headers, **kwargs
165165
)
166-
return self._internal_request(request, stream, stream=stream)
166+
return self._internal_request(request, stream)
167167

168168
def _internal_request(
169169
self,

deepl/translator.py

+83-26
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _api_call(
117117
stream: bool = False,
118118
headers: Optional[dict] = None,
119119
**kwargs,
120-
) -> Tuple[int, Union[str, requests.Response], dict]:
120+
) -> Tuple[int, Union[str, requests.Response], Any]:
121121
"""
122122
Makes a request to the API, and returns response as status code,
123123
content and JSON object.
@@ -167,14 +167,14 @@ def _raise_for_status(
167167
self,
168168
status_code: int,
169169
content: Union[str, requests.Response],
170-
json: Optional[dict],
170+
json: Any,
171171
glossary: bool = False,
172172
downloading_document: bool = False,
173173
):
174174
message = ""
175-
if json is not None and "message" in json:
175+
if json is not None and isinstance(json, dict) and "message" in json:
176176
message += ", message: " + json["message"]
177-
if json is not None and "detail" in json:
177+
if json is not None and isinstance(json, dict) and "detail" in json:
178178
message += ", detail: " + json["detail"]
179179

180180
if 200 <= status_code < 400:
@@ -296,7 +296,7 @@ def _create_glossary(
296296
source_lang: Union[str, Language],
297297
target_lang: Union[str, Language],
298298
entries_format: str,
299-
entries: str,
299+
entries: Union[str, bytes],
300300
) -> GlossaryInfo:
301301
# glossaries are only supported for base language types
302302
source_lang = Language.remove_regional_variant(source_lang)
@@ -437,11 +437,19 @@ def join_tags(tag_argument: Union[str, Iterable[str]]) -> List[str]:
437437

438438
self._raise_for_status(status, content, json)
439439

440-
translations = json.get("translations", [])
440+
translations = (
441+
json.get("translations", [])
442+
if (json and isinstance(json, dict))
443+
else []
444+
)
441445
output = []
442446
for translation in translations:
443-
text = translation.get("text")
444-
lang = translation.get("detected_source_language")
447+
text = translation.get("text", "") if translation else ""
448+
lang = (
449+
translation.get("detected_source_language", "")
450+
if translation
451+
else ""
452+
)
445453
output.append(TextResult(text, detected_source_lang=lang))
446454

447455
return output if multi_input else output[0]
@@ -633,6 +641,7 @@ def translate_document_upload(
633641
source_lang, target_lang, formality, glossary
634642
)
635643

644+
files: Dict[str, Any] = {}
636645
if isinstance(input_document, (str, bytes)):
637646
if filename is None:
638647
raise ValueError(
@@ -647,7 +656,11 @@ def translate_document_upload(
647656
)
648657
self._raise_for_status(status, content, json)
649658

650-
return DocumentHandle(json["document_id"], json["document_key"])
659+
if not json:
660+
json = {}
661+
return DocumentHandle(
662+
json.get("document_id", ""), json.get("document_key", "")
663+
)
651664

652665
def translate_document_get_status(
653666
self, handle: DocumentHandle
@@ -657,19 +670,42 @@ def translate_document_get_status(
657670
658671
:param handle: DocumentHandle to the request to check.
659672
:return: DocumentStatus containing the request status.
673+
674+
:raises DocumentTranslationException: If an error occurs during
675+
querying the document, the exception includes the document handle.
660676
"""
661677

662678
data = {"document_key": handle.document_key}
663679
url = f"v2/document/{handle.document_id}"
664680

665-
status, content, json = self._api_call(url, json=data)
681+
status_code, content, json = self._api_call(url, json=data)
666682

667-
self._raise_for_status(status, content, json)
683+
self._raise_for_status(status_code, content, json)
668684

669-
status = json["status"]
670-
seconds_remaining = json.get("seconds_remaining", None)
671-
billed_characters = json.get("billed_characters", None)
672-
error_message = json.get("error_message", None)
685+
status = (
686+
json.get("status", None)
687+
if (json and isinstance(json, dict))
688+
else None
689+
)
690+
if not status:
691+
raise DocumentTranslationException(
692+
"Querying document status gave an empty response", handle
693+
)
694+
seconds_remaining = (
695+
json.get("seconds_remaining", None)
696+
if (json and isinstance(json, dict))
697+
else None
698+
)
699+
billed_characters = (
700+
json.get("billed_characters", None)
701+
if (json and isinstance(json, dict))
702+
else None
703+
)
704+
error_message = (
705+
json.get("error_message", None)
706+
if (json and isinstance(json, dict))
707+
else None
708+
)
673709
return DocumentStatus(
674710
status, seconds_remaining, billed_characters, error_message
675711
)
@@ -726,17 +762,17 @@ def translate_document_download(
726762
status_code, response, json = self._api_call(
727763
url, json=data, stream=True
728764
)
765+
# TODO: once we drop py3.6 support, replace this with @overload
766+
# annotations in `_api_call` and chained private functions.
767+
# See for example https://stackoverflow.com/a/74070166/4926599
768+
assert isinstance(response, requests.Response)
729769

730770
self._raise_for_status(
731771
status_code, "<file>", json, downloading_document=True
732772
)
733773

734774
if output_file:
735-
chunks = (
736-
response.iter_content(chunk_size=chunk_size)
737-
if isinstance(response, requests.Response)
738-
else [response]
739-
)
775+
chunks = response.iter_content(chunk_size=chunk_size)
740776
for chunk in chunks:
741777
output_file.write(chunk)
742778
return None
@@ -753,12 +789,13 @@ def get_source_languages(self, skip_cache=False) -> List[Language]:
753789
"""
754790
status, content, json = self._api_call("v2/languages", method="GET")
755791
self._raise_for_status(status, content, json)
792+
languages = json if (json and isinstance(json, list)) else []
756793
return [
757794
Language(
758795
language["language"],
759796
language["name"],
760797
)
761-
for language in json
798+
for language in languages
762799
]
763800

764801
def get_target_languages(self, skip_cache=False) -> List[Language]:
@@ -774,13 +811,14 @@ def get_target_languages(self, skip_cache=False) -> List[Language]:
774811
"v2/languages", method="GET", data=data
775812
)
776813
self._raise_for_status(status, content, json)
814+
languages = json if (json and isinstance(json, list)) else []
777815
return [
778816
Language(
779817
language["language"],
780818
language["name"],
781819
language.get("supports_formality", None),
782820
)
783-
for language in json
821+
for language in languages
784822
]
785823

786824
def get_glossary_languages(self) -> List[GlossaryLanguagePair]:
@@ -791,11 +829,16 @@ def get_glossary_languages(self) -> List[GlossaryLanguagePair]:
791829

792830
self._raise_for_status(status, content, json)
793831

832+
supported_languages = (
833+
json.get("supported_languages", [])
834+
if (json and isinstance(json, dict))
835+
else []
836+
)
794837
return [
795838
GlossaryLanguagePair(
796839
language_pair["source_lang"], language_pair["target_lang"]
797840
)
798-
for language_pair in json["supported_languages"]
841+
for language_pair in supported_languages
799842
]
800843

801844
def get_usage(self) -> Usage:
@@ -804,6 +847,8 @@ def get_usage(self) -> Usage:
804847

805848
self._raise_for_status(status, content, json)
806849

850+
if not isinstance(json, dict):
851+
json = {}
807852
return Usage(json)
808853

809854
def create_glossary(
@@ -888,6 +933,8 @@ def create_glossary_from_csv(
888933
csv_data if isinstance(csv_data, (str, bytes)) else csv_data.read()
889934
)
890935

936+
if not isinstance(entries, (bytes, str)):
937+
raise ValueError("Entries of the glossary are invalid")
891938
return self._create_glossary(
892939
name, source_lang, target_lang, "csv", entries
893940
)
@@ -913,9 +960,12 @@ def list_glossaries(self) -> List[GlossaryInfo]:
913960
"""
914961
status, content, json = self._api_call("v2/glossaries", method="GET")
915962
self._raise_for_status(status, content, json, glossary=True)
916-
return [
917-
GlossaryInfo.from_json(glossary) for glossary in json["glossaries"]
918-
]
963+
glossaries = (
964+
json.get("glossaries", [])
965+
if (json and isinstance(json, dict))
966+
else []
967+
)
968+
return [GlossaryInfo.from_json(glossary) for glossary in glossaries]
919969

920970
def get_glossary_entries(self, glossary: Union[str, GlossaryInfo]) -> dict:
921971
"""Retrieves the entries of the specified glossary and returns them as
@@ -925,6 +975,8 @@ def get_glossary_entries(self, glossary: Union[str, GlossaryInfo]) -> dict:
925975
:return: dictionary of glossary entries.
926976
:raises GlossaryNotFoundException: If no glossary with given ID is
927977
found.
978+
:raises DeepLException: If the glossary could not be retrieved
979+
in the right format.
928980
"""
929981
if isinstance(glossary, GlossaryInfo):
930982
glossary_id = glossary.glossary_id
@@ -937,6 +989,11 @@ def get_glossary_entries(self, glossary: Union[str, GlossaryInfo]) -> dict:
937989
headers={"Accept": "text/tab-separated-values"},
938990
)
939991
self._raise_for_status(status, content, json, glossary=True)
992+
if not isinstance(content, str):
993+
raise DeepLException(
994+
"Could not get the glossary content as a string",
995+
http_status_code=status,
996+
)
940997
return util.convert_tsv_to_dict(content)
941998

942999
def delete_glossary(self, glossary: Union[str, GlossaryInfo]) -> None:

deepl/util.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def log_info(message, **kwargs):
4949
logger.info(text)
5050

5151

52-
def get_int_safe(d: dict, key: str) -> Optional[int]:
52+
def get_int_safe(d: Optional[dict], key: str) -> Optional[int]:
5353
"""Returns value in dictionary with given key as int, or None."""
5454
try:
55-
return int(d.get(key))
55+
return int(d.get(key) if d else None) # type: ignore[arg-type]
5656
except (TypeError, ValueError):
5757
return None
5858

0 commit comments

Comments
 (0)