Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: optionally output graph in pajek format #293

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub fn cluster(
similarity_column: String,
similarity_threshold: f64,
cluster_sizes: Option<String>,
output_graph: Option<String>,
) -> Result<()> {
let (graph, name_to_node) =
match build_graph(&pairwise_csv, &similarity_column, similarity_threshold) {
Expand Down Expand Up @@ -133,7 +134,7 @@ pub fn cluster(
*count += 1;
}

// write the sizes and counts
// Optionally, write the sizes and counts
if let Some(sizes_file) = cluster_sizes {
let mut cluster_size_file =
File::create(sizes_file).context("Failed to create cluster size file")?;
Expand All @@ -145,5 +146,29 @@ pub fn cluster(
}
}

// Optionally, write graph file in Pajek format. Note that '.net' is typical extension
if let Some(graph_file) = output_graph {
let mut out_graph_file =
File::create(graph_file).context("Failed to create graph output file")?;

// Write the network header
writeln!(out_graph_file, "*Vertices {}", graph.node_count())?;
for node in 0..graph.node_count() {
writeln!(out_graph_file, "{}", node + 1)?; // Pajek uses 1-based indexing
}
writeln!(out_graph_file, "*Edges")?;

// Write the edges
for edge in graph.edge_indices() {
let (source, target) = graph.edge_endpoints(edge).unwrap();
writeln!(
out_graph_file,
"{} {}",
source.index() + 1,
target.index() + 1
)?; // Pajek uses 1-based indexing
}
}

Ok(())
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,15 @@ fn do_cluster(
similarity_column: String,
similarity_threshold: f64,
cluster_sizes: Option<String>,
output_graph: Option<String>,
) -> anyhow::Result<u8> {
match cluster::cluster(
pairwise_csv,
output_clusters,
similarity_column,
similarity_threshold,
cluster_sizes,
output_graph,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand Down
13 changes: 10 additions & 3 deletions src/python/sourmash_plugin_branchwater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,9 @@ def __init__(self, p):
p.add_argument('-o', '--output', required=True,
help='output csv file for the clusters')
p.add_argument('--cluster-sizes', default=None,
help='output file for the cluster size histogram')
help='Optionally, output cluster size histogram to this file')
p.add_argument('--graph-output', default=None,
help='Optionally, output graph in Pajek format to this file (recommended file extension: ".net")')
p.add_argument('--similarity-column', type=str, default='average_containment_ani',
choices=['containment', 'max_containment', 'jaccard', 'average_containment_ani', 'max_containment_ani'],
help='column to use as similarity measure')
Expand All @@ -402,8 +404,13 @@ def main(self, args):
args.output,
args.similarity_column,
args.threshold,
args.cluster_sizes)
args.cluster_sizes,
args.graph_output)
if status == 0:
notify(f"...clustering is done! results in '{args.output}'")
notify(f" cluster counts in '{args.cluster_sizes}'")
if args.cluster_sizes:
notify(f" cluster counts in '{args.cluster_sizes}'")
if args.graph_output:
notify(f" Pajek-format graph output in '{args.graph_output}'")

return status
143 changes: 143 additions & 0 deletions src/python/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,146 @@ def test_bad_file(runtmp, capfd):
print(captured.err)

assert "Error: Failed to build graph" in captured.err


def test_cluster_ani_output_graph(runtmp):
    """Run 'cluster' with --cluster-sizes and --graph-output, then verify the
    cluster CSV, the cluster-size histogram, and the Pajek-format graph file.
    """
    pairwise_csv = get_test_data('cluster.pairwise.csv')
    output = runtmp.output('clusters.csv')
    sizes = runtmp.output('sizes.csv')
    graph = runtmp.output('clustergraph.net')
    threshold = '0.9'

    runtmp.sourmash('scripts', 'cluster', pairwise_csv, '-o', output,
                    '--similarity-column', "average_containment_ani", "--cluster-sizes",
                    sizes, '--threshold', threshold, "--graph-output", graph)

    assert os.path.exists(output)

    # check cluster output
    with open(output, mode='r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        rows = [row for row in reader]
        assert reader.fieldnames == ['cluster','nodes']
        assert len(rows) == 2, f"Expected 2 data rows but found {len(rows)}"
        assert rows[0]['cluster'] == 'Component_1'
        # Cluster membership is order-independent, so compare node *sets*.
        expected_node_sets = [
            set("n1;n2;n3;n4;n5".split(';')),
            set("n6;n7".split(';'))
        ]
        for row in rows:
            assert set(row['nodes'].split(';')) in expected_node_sets

    # check cluster size histogram
    with open(sizes, mode='r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        rows = [row for row in reader]
        assert reader.fieldnames == ['cluster_size','count']
        assert len(rows) == 2, f"Expected 2 data rows but found {len(rows)}"
        rows_as_tuples = {tuple(row.values()) for row in rows}
        expected = {('5', '1'), ('2', '1')}
        assert rows_as_tuples == expected

    # check graph output. Pajek layout: '*Vertices N' header, one vertex id
    # per line (1-based), then '*Edges' and space-separated endpoint pairs.
    expected_vertex_count = 7
    expected_vertices = ['1', '2', '3', '4', '5', '6', '7']
    expected_edges = [('1', '2'), ('1', '3'), ('2', '3'), ('2', '4'), ('3', '4'), ('4', '5'), ('6', '7')]
    assert os.path.exists(graph)
    with open(graph, 'r', newline='') as pajek_graph:
        reader = csv.reader(pajek_graph, delimiter=' ')

        # BUGFIX: 'section' was previously unbound until the first header row;
        # a malformed file with data before '*Vertices' raised NameError
        # instead of failing an assertion. Initialize it and assert instead.
        section = None
        found_vertices = []
        found_edges = []

        for row in reader:
            if not row:
                continue

            if row[0] == "*Vertices":
                section = "Vertices"
                continue

            if row[0] == "*Edges":
                section = "Edges"
                continue

            assert section is not None, f"data row before any Pajek section header: {row}"

            if section == "Vertices":
                found_vertices.append(row[0])

            if section == "Edges":
                found_edges.append((row[0], row[1]))

    # Check found vertices and edges against expected values
    assert len(found_vertices) == expected_vertex_count
    assert found_vertices == expected_vertices, f"Vertices don't match: {found_vertices}"
    assert found_edges == expected_edges, f"Edges don't match: {found_edges}"


def test_cluster_ani_pairwise_graph_output(runtmp):
    """End-to-end: run 'pairwise' on three signatures, cluster the result
    with --graph-output, and verify the cluster CSV and Pajek graph file.
    """
    pairwise_csv = runtmp.output('pairwise.csv')
    output = runtmp.output('clusters.csv')
    graph = runtmp.output('clustergraph.net')
    cluster_threshold = '0.90'

    query_list = runtmp.output('query.txt')
    sig2 = get_test_data('2.fa.sig.gz')
    sig47 = get_test_data('47.fa.sig.gz')
    sig63 = get_test_data('63.fa.sig.gz')

    make_file_list(query_list, [sig2, sig47, sig63])

    # '-t -0.1' keeps every pair in the output regardless of similarity.
    runtmp.sourmash('scripts', 'pairwise', query_list,
                    '-o', pairwise_csv, "-t", "-0.1", "--ani")

    assert os.path.exists(pairwise_csv)

    runtmp.sourmash('scripts', 'cluster', pairwise_csv, '-o', output,
                    '--similarity-column', "average_containment_ani", "--graph-output",
                    graph, '--threshold', cluster_threshold)

    assert os.path.exists(output)

    # check cluster output
    with open(output, mode='r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        rows = [row for row in reader]
        assert reader.fieldnames == ['cluster','nodes']
        print(rows)
        assert len(rows) == 2, f"Expected 2 data rows but found {len(rows)}"
        assert rows[0]['cluster'] == 'Component_1'
        # Membership within a cluster is order-independent; compare node sets.
        expected_node_sets = [set("NC_009661.1;NC_011665.1".split(';')), set("CP001071.1".split(';'))]
        for row in rows:
            assert set(row['nodes'].split(';')) in expected_node_sets

    # check graph output. Pajek layout: '*Vertices N' header, one vertex id
    # per line (1-based), then '*Edges' and space-separated endpoint pairs.
    expected_vertex_count = 3
    expected_vertices = ['1', '2', '3']
    n_expected_edges = 1
    assert os.path.exists(graph)
    with open(graph, 'r', newline='') as pajek_graph:
        reader = csv.reader(pajek_graph, delimiter=' ')

        # BUGFIX: 'section' was previously unbound until the first header row;
        # a malformed file with data before '*Vertices' raised NameError
        # instead of failing an assertion. Initialize it and assert instead.
        section = None
        found_vertices = []
        found_edges = []

        for row in reader:
            if not row:
                continue

            if row[0] == "*Vertices":
                section = "Vertices"
                continue

            if row[0] == "*Edges":
                section = "Edges"
                continue

            assert section is not None, f"data row before any Pajek section header: {row}"

            if section == "Vertices":
                found_vertices.append(row[0])

            if section == "Edges":
                found_edges.append((row[0], row[1]))

    # Check found vertices and edges against expected values
    assert len(found_vertices) == expected_vertex_count
    assert found_vertices == expected_vertices, f"Vertices don't match: {found_vertices}"
    assert len(found_edges) == n_expected_edges, f"Edge count doesn't match: found edges: {found_edges}"
Loading