Skip to content

Commit 514e480

Browse files
committed
WIP: Add NmtFileService and reading/writing of per-text cached analysis
1 parent 5efc056 commit 514e480

File tree

5 files changed

+80
-25
lines changed

5 files changed

+80
-25
lines changed

machine/jobs/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .clearml_shared_file_service import ClearMLSharedFileService
22
from .local_shared_file_service import LocalSharedFileService
33
from .nmt_engine_build_job import NmtEngineBuildJob
4+
from .nmt_file_service import CachedAnalysisInfo, NmtFileService
45
from .nmt_model_factory import NmtModelFactory
56
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
67
from .smt_engine_build_job import SmtEngineBuildJob
@@ -13,9 +14,11 @@
1314
from .word_alignment_model_factory import WordAlignmentModelFactory
1415

1516
__all__ = [
17+
"CachedAnalysisInfo",
1618
"ClearMLSharedFileService",
1719
"LocalSharedFileService",
1820
"NmtEngineBuildJob",
21+
"NmtFileService",
1922
"NmtModelFactory",
2023
"DictToJsonWriter",
2124
"SharedFileServiceBase",

machine/jobs/build_nmt_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from .build_clearml_helper import report_clearml_progress, update_settings
1010
from .config import SETTINGS
1111
from .nmt_engine_build_job import NmtEngineBuildJob
12+
from .nmt_file_service import NmtFileService
1213
from .nmt_model_factory import NmtModelFactory
1314
from .shared_file_service_factory import SharedFileServiceType
14-
from .translation_file_service import TranslationFileService
1515

1616
# Setup logging
1717
logging.basicConfig(
@@ -47,7 +47,7 @@ def clearml_progress(status: ProgressStatus) -> None:
4747
logger.info("NMT Engine Build Job started")
4848
update_settings(SETTINGS, args, task, logger)
4949

50-
translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS)
50+
translation_file_service = NmtFileService(SharedFileServiceType.CLEARML, SETTINGS)
5151
nmt_model_factory: NmtModelFactory
5252
if SETTINGS.model_type == "huggingface":
5353
from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory

machine/jobs/nmt_engine_build_job.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,36 @@
1010
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
1111
from ..utils.progress_status import ProgressStatus
1212
from .eflomal_aligner import EflomalAligner, is_eflomal_available, tokenize
13+
from .nmt_file_service import NmtFileService
1314
from .nmt_model_factory import NmtModelFactory
1415
from .translation_engine_build_job import TranslationEngineBuildJob
15-
from .translation_file_service import PretranslationInfo, TranslationFileService
16+
from .translation_file_service import PretranslationInfo
1617

1718
logger = logging.getLogger(__name__)
1819

1920

2021
class NmtEngineBuildJob(TranslationEngineBuildJob):
21-
def __init__(
22-
self, config: Any, nmt_model_factory: NmtModelFactory, translation_file_service: TranslationFileService
23-
) -> None:
22+
def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, nmt_file_service: NmtFileService) -> None:
23+
self._nmt_file_service = nmt_file_service
2424
self._nmt_model_factory = nmt_model_factory
2525
self._nmt_model_factory.init()
26-
super().__init__(config, translation_file_service)
26+
super().__init__(config, nmt_file_service)
2727

2828
def _get_progress_reporter(
2929
self, progress: Optional[Callable[[ProgressStatus], None]], corpus_size: int
3030
) -> PhasedProgressReporter:
3131
if corpus_size > 0:
3232
if self._config.align_pretranslations:
3333
phases = [
34-
Phase(message="Training NMT model", percentage=0.8, stage="train"),
34+
Phase(message="Detecting quotation conventions", percentage=0.01),
35+
Phase(message="Training NMT model", percentage=0.79, stage="train"),
3536
Phase(message="Pretranslating segments", percentage=0.1, stage="inference"),
3637
Phase(message="Aligning segments", percentage=0.1, report_steps=False),
3738
]
3839
else:
3940
phases = [
40-
Phase(message="Training NMT model", percentage=0.9, stage="train"),
41+
Phase(message="Detecting quotation conventions", percentage=0.01),
42+
Phase(message="Training NMT model", percentage=0.89, stage="train"),
4143
Phase(message="Pretranslating segments", percentage=0.1, stage="inference"),
4244
]
4345
else:
@@ -82,6 +84,10 @@ def _train_model(
8284
if check_canceled is not None:
8385
check_canceled()
8486

87+
logger.info("Detecting quotation conventions")
88+
progress_reporter.start_next_phase()
89+
# TODO: Detect quotation conventions
90+
8591
logger.info("Training NMT model")
8692
with (
8793
progress_reporter.start_next_phase() as phase_progress,

machine/jobs/nmt_file_service.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from contextlib import contextmanager
2+
from typing import Generator, Iterator, TypedDict
3+
4+
import json_stream
5+
6+
from ..utils.context_managed_generator import ContextManagedGenerator
7+
from .shared_file_service_base import DictToJsonWriter
8+
from .translation_file_service import TranslationFileService
9+
10+
11+
class CachedAnalysisInfo(TypedDict):
12+
corpusId: str # noqa: N815
13+
textId: str # noqa: N815
14+
quoteConvention: str # noqa: N815
15+
16+
17+
SOURCE_CACHED_ANALYSIS_FILENAME = "cached_analysis.src.json"
18+
TARGET_CACHED_ANALYSIS_FILENAME = "cached_analysis.trg.json"
19+
20+
21+
class NmtFileService(TranslationFileService):
22+
23+
def _get_cached_analysis(self, file_name: str) -> ContextManagedGenerator[CachedAnalysisInfo, None, None]:
24+
src_cached_analysis_path = self.shared_file_service.download_file(
25+
f"{self.shared_file_service.build_path}/{file_name}"
26+
)
27+
28+
def generator() -> Generator[CachedAnalysisInfo, None, None]:
29+
with src_cached_analysis_path.open("r", encoding="utf-8-sig") as file:
30+
for pi in json_stream.load(file):
31+
yield CachedAnalysisInfo(
32+
corpusId=pi["corpusId"],
33+
textId=pi["textId"],
34+
quoteConvention=pi["quoteConvention"],
35+
)
36+
37+
return ContextManagedGenerator(generator())
38+
39+
def get_source_cached_analysis(self) -> ContextManagedGenerator[CachedAnalysisInfo, None, None]:
40+
return self._get_cached_analysis(SOURCE_CACHED_ANALYSIS_FILENAME)
41+
42+
def get_target_cached_analysis(self) -> ContextManagedGenerator[CachedAnalysisInfo, None, None]:
43+
return self._get_cached_analysis(TARGET_CACHED_ANALYSIS_FILENAME)
44+
45+
@contextmanager
46+
def open_source_cached_analysis_writer(self) -> Iterator[DictToJsonWriter]:
47+
return self.shared_file_service.open_target_writer(SOURCE_CACHED_ANALYSIS_FILENAME)
48+
49+
@contextmanager
50+
def open_target_cached_analysis_writer(self) -> Iterator[DictToJsonWriter]:
51+
return self.shared_file_service.open_target_writer(TARGET_CACHED_ANALYSIS_FILENAME)

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,7 @@
1010

1111
from machine.annotations import Range
1212
from machine.corpora import DictionaryTextCorpus
13-
from machine.jobs import (
14-
DictToJsonWriter,
15-
NmtEngineBuildJob,
16-
NmtModelFactory,
17-
PretranslationInfo,
18-
TranslationFileService,
19-
)
13+
from machine.jobs import DictToJsonWriter, NmtEngineBuildJob, NmtFileService, NmtModelFactory, PretranslationInfo
2014
from machine.jobs.eflomal_aligner import is_eflomal_available
2115
from machine.translation import (
2216
Phrase,
@@ -54,7 +48,8 @@ def test_run(decoy: Decoy) -> None:
5448
assert pretranslations[0]["sourceTokens"] == []
5549
assert pretranslations[0]["translationTokens"] == []
5650
assert len(pretranslations[0]["alignment"]) == 0
57-
decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
51+
52+
decoy.verify(env.nmt_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
5853

5954

6055
def test_cancel(decoy: Decoy) -> None:
@@ -116,12 +111,12 @@ def __init__(self, decoy: Decoy) -> None:
116111
decoy.when(self.nmt_model_factory.create_engine()).then_return(self.engine)
117112
decoy.when(self.nmt_model_factory.save_model()).then_return(Path("model.tar.gz"))
118113

119-
self.translation_file_service = decoy.mock(cls=TranslationFileService)
120-
decoy.when(self.translation_file_service.create_source_corpus()).then_return(DictionaryTextCorpus())
121-
decoy.when(self.translation_file_service.create_target_corpus()).then_return(DictionaryTextCorpus())
122-
decoy.when(self.translation_file_service.exists_source_corpus()).then_return(True)
123-
decoy.when(self.translation_file_service.exists_target_corpus()).then_return(True)
124-
decoy.when(self.translation_file_service.get_source_pretranslations()).then_do(
114+
self.nmt_file_service = decoy.mock(cls=NmtFileService)
115+
decoy.when(self.nmt_file_service.create_source_corpus()).then_return(DictionaryTextCorpus())
116+
decoy.when(self.nmt_file_service.create_target_corpus()).then_return(DictionaryTextCorpus())
117+
decoy.when(self.nmt_file_service.exists_source_corpus()).then_return(True)
118+
decoy.when(self.nmt_file_service.exists_target_corpus()).then_return(True)
119+
decoy.when(self.nmt_file_service.get_source_pretranslations()).then_do(
125120
lambda: ContextManagedGenerator(
126121
(
127122
pi
@@ -150,7 +145,7 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJ
150145
file.write("\n]\n")
151146
env.target_pretranslations = file.getvalue()
152147

153-
decoy.when(self.translation_file_service.open_target_pretranslation_writer()).then_do(
148+
decoy.when(self.nmt_file_service.open_target_pretranslation_writer()).then_do(
154149
lambda: open_target_pretranslation_writer(self)
155150
)
156151

@@ -165,7 +160,7 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJ
165160
}
166161
),
167162
self.nmt_model_factory,
168-
self.translation_file_service,
163+
self.nmt_file_service,
169164
)
170165

171166

0 commit comments

Comments
 (0)