Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose EM in command line #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cherryml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "v0.0.8"
__version__ = "v0.0.9"

from cherryml._cherryml_public_api import cherryml_public_api
from cherryml.counting import count_co_transitions, count_transitions
Expand Down
21 changes: 21 additions & 0 deletions cherryml/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,27 @@ def none_or_value(value):
"increased statistical efficiency at essentially no runtime cost.",
)

# Experimental args to use EM, NOT exposed to users
parser.add_argument(
"--_em_backend",
type=str,
required=False,
default="",
help="Experimental argument, DO NOT USE. Whether to use EM instead "
"of CherryML. Can be either empty string '' (for default behaviour of "
"using CherryML), 'xrate' or 'historian'.",
)
parser.add_argument(
"--_extra_em_command_line_args",
type=str,
required=False,
default="-log 6 -f 3 -mi 0.000001",
help="Experimental argument, DO NOT USE. Extra EM command line "
"arguments. Defaults to '-log 6 -f 3 -mi 0.000001' for use with XRATE."
" Setting a larger -mi will result in faster runtime but less accurate"
" estimates."
)

# Functionality not currently exposed:
# parser.add_argument("--do_adam")
# parser.add_argument("--cpp_counting_command_line_prefix")
Expand Down
92 changes: 64 additions & 28 deletions cherryml/_cherryml_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from cherryml.estimation_end_to_end import (
coevolution_end_to_end_with_cherryml_optimizer,
lg_end_to_end_with_cherryml_optimizer,
lg_end_to_end_with_em_optimizer,
)
from cherryml.io import read_rate_matrix, write_rate_matrix
from cherryml.markov_chain import get_lg_path
Expand Down Expand Up @@ -65,6 +66,9 @@ def cherryml_public_api(
use_maximal_matching: bool = True,
families: Optional[List[str]] = None,
tree_estimator_name: str = "FastTree",
# Experimental args to use EM, NOT exposed to users
_em_backend: str = "",
_extra_em_command_line_args: str = "-log 6 -f 3 -mi 0.000001",
) -> str:
"""
CherryML method applied to the LG model and the co-evolution model.
Expand Down Expand Up @@ -165,34 +169,66 @@ def cherryml_public_api(
raise ValueError(f"Unknown tree_estimator_name: {tree_estimator_name}")

if model_name == "LG":
outputs = lg_end_to_end_with_cherryml_optimizer(
msa_dir=msa_dir,
families=families,
tree_estimator=partial(
tree_estimator,
num_rate_categories=num_rate_categories,
),
initial_tree_estimator_rate_matrix_path=initial_tree_estimator_rate_matrix_path, # noqa
num_iterations=num_iterations,
quantization_grid_center=quantization_grid_center,
quantization_grid_step=quantization_grid_step,
quantization_grid_num_steps=quantization_grid_num_steps,
use_cpp_counting_implementation=use_cpp_counting_implementation,
optimizer_device=optimizer_device,
learning_rate=learning_rate,
num_epochs=num_epochs,
do_adam=do_adam,
edge_or_cherry=cherryml_type,
cpp_counting_command_line_prefix=cpp_counting_command_line_prefix,
cpp_counting_command_line_suffix=cpp_counting_command_line_suffix,
num_processes_tree_estimation=num_processes_tree_estimation,
num_processes_counting=num_processes_counting,
num_processes_optimization=num_processes_optimization,
optimizer_initialization=optimizer_initialization,
sites_subset_dir=sites_subset_dir,
tree_dir=tree_dir,
site_rates_dir=site_rates_dir,
)
if _em_backend != "":
if _em_backend not in ["xrate", "historian"]:
raise ValueError(
"_em_backend may only be 'xrate' or 'historian' (or the "
"empty string '' for default CherryML behaviour). You "
f"provided: '{_em_backend}'."
)
outputs = lg_end_to_end_with_em_optimizer(
msa_dir=msa_dir,
families=families,
tree_estimator=partial(
tree_estimator,
num_rate_categories=num_rate_categories,
),
initial_tree_estimator_rate_matrix_path=initial_tree_estimator_rate_matrix_path, # noqa
num_iterations=num_iterations,
quantization_grid_center=quantization_grid_center,
quantization_grid_step=quantization_grid_step,
quantization_grid_num_steps=quantization_grid_num_steps,
use_cpp_counting_implementation=use_cpp_counting_implementation, # noqa
extra_em_command_line_args=_extra_em_command_line_args,
cpp_counting_command_line_prefix=cpp_counting_command_line_prefix, # noqa
cpp_counting_command_line_suffix=cpp_counting_command_line_suffix, # noqa
num_processes_tree_estimation=num_processes_tree_estimation,
num_processes_counting=num_processes_counting,
num_processes_optimization=num_processes_optimization,
optimizer_initialization=optimizer_initialization,
sites_subset_dir=sites_subset_dir,
em_backend=_em_backend,
)
else:
# Just standard CherryML (the only thing we expose to users)
outputs = lg_end_to_end_with_cherryml_optimizer(
msa_dir=msa_dir,
families=families,
tree_estimator=partial(
tree_estimator,
num_rate_categories=num_rate_categories,
),
initial_tree_estimator_rate_matrix_path=initial_tree_estimator_rate_matrix_path, # noqa
num_iterations=num_iterations,
quantization_grid_center=quantization_grid_center,
quantization_grid_step=quantization_grid_step,
quantization_grid_num_steps=quantization_grid_num_steps,
use_cpp_counting_implementation=use_cpp_counting_implementation, # noqa
optimizer_device=optimizer_device,
learning_rate=learning_rate,
num_epochs=num_epochs,
do_adam=do_adam,
edge_or_cherry=cherryml_type,
cpp_counting_command_line_prefix=cpp_counting_command_line_prefix, # noqa
cpp_counting_command_line_suffix=cpp_counting_command_line_suffix, # noqa
num_processes_tree_estimation=num_processes_tree_estimation,
num_processes_counting=num_processes_counting,
num_processes_optimization=num_processes_optimization,
optimizer_initialization=optimizer_initialization,
sites_subset_dir=sites_subset_dir,
tree_dir=tree_dir,
site_rates_dir=site_rates_dir,
)
learned_rate_matrix = read_rate_matrix(
os.path.join(outputs["learned_rate_matrix_path"])
)
Expand Down