Skip to content

Commit

Permalink
updated plot_tools with path registry
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Feb 23, 2024
1 parent 456e012 commit 7b613ad
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 72 deletions.
150 changes: 92 additions & 58 deletions mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,88 @@
from mdagent.utils import PathRegistry


def process_csv(file_name):
    """Read a CSV file and locate its step/time columns.

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        A ``(data, headers, matched_headers)`` tuple: ``data`` is the list of
        row dicts produced by ``csv.DictReader``, ``headers`` is the list of
        column names (empty for an empty file), and ``matched_headers`` is a
        list of ``(index, header)`` pairs for every column whose name contains
        "step" or "time" (case-insensitive).
    """
    # newline="" hands line-ending handling to the csv module, as its
    # documentation recommends, so newlines embedded in quoted fields survive.
    with open(file_name, "r", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None when the file has no header row; normalise to an
        # empty list so the header scan below cannot raise TypeError.
        headers = reader.fieldnames or []
        data = list(reader)

    matched_headers = [
        (i, header)
        for i, header in enumerate(headers)
        if re.search(r"(step|time)", header, re.IGNORECASE)
    ]

    return data, headers, matched_headers


def plot_data(data, headers, matched_headers):
    """Plot every column against the first step/time column and save PNGs.

    Args:
        data: List of row dicts (column name -> string value).
        headers: List of all column names.
        matched_headers: List of ``(index, header)`` pairs for the step/time
            columns; only the first entry is used as the x axis.

    Returns:
        A comma-separated string of the saved plot file names, or ``None``
        (after printing a message) when ``matched_headers`` is empty.

    Raises:
        Exception: If every non-x column failed to convert to float.
    """
    # Guard clause: nothing can be plotted without an x-axis column.
    if not matched_headers:
        print("No 'step' or 'time' headers found.")
        return

    x_column = matched_headers[0][1]
    if "step" in x_column.lower():
        xlab = "step"
    else:
        xlab = "time"

    skipped = []
    saved = []
    for column in headers:
        if column == x_column:
            continue
        try:
            xs = [float(record[x_column]) for record in data]
            ys = [float(record[column]) for record in data]

            # Drop a trailing "(unit)" part, if any, from the column label.
            if "(" in column:
                label = column.split("(")[0].strip()
            else:
                label = column
            label = label.lower()
            out_file = f"{xlab}_vs_{label}.png"

            # Generate and save the plot
            plt.figure()
            plt.plot(xs, ys)
            plt.xlabel(xlab)
            plt.ylabel(column)
            plt.title(f"{xlab} vs {label}")
            plt.savefig(out_file)
            plt.close()

            saved.append(out_file)
        except ValueError:
            skipped.append(column)

    # +1 accounts for the x-axis column, which is never plotted against itself.
    if len(skipped) + 1 == len(headers):
        raise Exception("All plots failed due to non-numeric data.")

    return ", ".join(saved)
class PlottingTools:
    """Create per-column plots from a CSV file tracked by a path registry.

    Intended call order is ``_find_file`` -> ``process_csv`` -> ``plot_data``;
    each step stores its result on the instance for the next one.
    """

    def __init__(
        self,
        path_registry,
    ):
        """Store the registry and reset all per-file state.

        Args:
            path_registry: Object providing ``get_mapped_path(name)`` and
                ``map_path(name, path, description)``.
        """
        self.path_registry = path_registry
        self.data = None
        self.headers = None
        self.matched_headers = None
        self.file_name = None
        self.file_path = None

    def _find_file(self, file_name: str) -> None:
        """Resolve ``file_name`` through the registry and remember the result.

        Raises:
            FileNotFoundError: If the registry returns a falsy path.
        """
        self.file_name = file_name
        self.file_path = self.path_registry.get_mapped_path(file_name)
        if not self.file_path:
            raise FileNotFoundError("File not found.")
        return None

    def process_csv(self) -> None:
        """Load ``self.file_path`` and record its rows, headers and x-axis columns.

        Sets ``self.headers``, ``self.data`` and ``self.matched_headers`` (the
        ``(index, header)`` pairs whose name contains "step" or "time",
        case-insensitive).

        Raises:
            ValueError: If the file has no headers, no rows, or no step/time
                column.
        """
        # newline="" hands line-ending handling to the csv module, as its
        # documentation recommends.
        with open(self.file_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            self.headers = reader.fieldnames if reader.fieldnames is not None else []
            self.data = list(reader)

        self.matched_headers = [
            (i, header)
            for i, header in enumerate(self.headers)
            if re.search(r"(step|time)", header, re.IGNORECASE)
        ]

        if not self.matched_headers or not self.headers or not self.data:
            raise ValueError("File could not be processed.")
        return None

    def plot_data(self) -> str:
        """Plot every column against the first step/time column.

        Each plot is saved as ``<file_name>_<xlab>_vs_<column>.png`` and
        registered with the path registry.

        Returns:
            A comma-separated string of the created plot names, or a message
            string when no step/time column is known.

        Raises:
            Exception: If every non-x column failed with a ``ValueError``
                (non-numeric data).
        """
        if self.matched_headers:
            time_or_step = self.matched_headers[0][1]
            xlab = "step" if "step" in time_or_step.lower() else "time"
        else:
            return "No 'step' or 'time' headers found."

        failed_headers = []
        created_plots = []
        for header in self.headers:
            if header == time_or_step:
                continue
            try:
                x = [float(row[time_or_step]) for row in self.data]
                y = [float(row[header]) for row in self.data]

                # Drop a trailing "(unit)" part, if any, from the label.
                header_lab = (
                    header.split("(")[0].strip() if "(" in header else header
                ).lower()
                plot_name = f"{self.file_name}_{xlab}_vs_{header_lab}.png"

                # Generate and save the plot
                plt.figure()
                try:
                    plt.plot(x, y)
                    plt.xlabel(xlab)
                    plt.ylabel(header)
                    plt.title(f"{self.file_name}_{xlab} vs {header_lab}")
                    plt.savefig(plot_name)
                    # The description must be an f-string; without the prefix
                    # the registry stored the literal "{self.file_name}" text.
                    self.path_registry.map_path(
                        plot_name,
                        plot_name,
                        (
                            "Post Simulation Figure for "
                            f"{self.file_name} - {header_lab} vs {xlab}"
                        ),
                    )
                finally:
                    # Always release the figure, even if saving or
                    # registering the plot raised.
                    plt.close()

                created_plots.append(plot_name)
            except ValueError:
                failed_headers.append(header)

        if (
            len(failed_headers) == len(self.headers) - 1
        ):  # -1 to account for time_or_step header
            raise Exception("All plots failed due to non-numeric data.")

        return ", ".join(created_plots)


class SimulationOutputFigures(BaseTool):
Expand All @@ -76,17 +104,23 @@ class SimulationOutputFigures(BaseTool):

path_registry: Optional[PathRegistry]

def _run(self, file_path: str) -> str:
# Initialise the tool and remember the path registry it should use.
# path_registry: optional PathRegistry instance; defaults to None.
def __init__(self, path_registry: Optional[PathRegistry] = None):
super().__init__()
# Assigned only after the base-class initialiser has run.
self.path_registry = path_registry

def _run(self, file_name: str) -> str:
"""use the tool."""
try:
data, headers, matched_headers = process_csv(file_path)
plot_result = plot_data(data, headers, matched_headers)
plotting_tools = PlottingTools(self.path_registry)
plotting_tools._find_file(file_name)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
return "Figures created: " + plot_result
else:
return "No figures created."
except ValueError:
return "No timestep data found in csv file."
return "File could not be processed."
except FileNotFoundError:
return "Issue with CSV file, file not found."
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def make_all_tools(
RMSDCalculator(),
SetUpandRunFunction(path_registry=path_instance),
ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm),
SimulationOutputFigures(),
SimulationOutputFigures(path_registry=path_instance),
]
if subagent_settings is None:
subagent_settings = SubAgentSettings(path_registry=path_instance)
Expand Down
42 changes: 29 additions & 13 deletions tests/test_fxns.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
VisFunctions,
get_pdb,
)
from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv
from mdagent.tools.base_tools.analysis_tools.plot_tools import PlottingTools
from mdagent.tools.base_tools.preprocess_tools.pdb_tools import MolPDB, PackMolTool
from mdagent.utils import FileType, PathRegistry

Expand Down Expand Up @@ -69,12 +69,17 @@ def get_registry():
return PathRegistry()


# Fixture: a PlottingTools instance built from the get_registry fixture's value.
@pytest.fixture
def plotting_tools(get_registry):
return PlottingTools(get_registry)


# Fixture: a PackMolTool instance built from the get_registry fixture's value.
@pytest.fixture
def packmol(get_registry):
return PackMolTool(get_registry)


def test_process_csv():
def test_process_csv(plotting_tools):
mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25"
mock_reader = MagicMock()
mock_reader.fieldnames = ["Time", "Value1", "Value2"]
Expand All @@ -84,19 +89,23 @@ def test_process_csv():
{"Time": "2", "Value1": "15", "Value2": "25"},
]
)

plotting_tools.file_path = "mock_file.csv"
plotting_tools.file_name = "mock_file.csv"
with patch("builtins.open", mock_open(read_data=mock_csv_content)):
with patch("csv.DictReader", return_value=mock_reader):
data, headers, matched_headers = process_csv("mock_file.csv")

assert headers == ["Time", "Value1", "Value2"]
assert len(matched_headers) == 1
assert matched_headers[0][1] == "Time"
assert len(data) == 2
assert data[0]["Time"] == "1" and data[0]["Value1"] == "10"
plotting_tools.process_csv()

assert plotting_tools.headers == ["Time", "Value1", "Value2"]
assert len(plotting_tools.matched_headers) == 1
assert plotting_tools.matched_headers[0][1] == "Time"
assert len(plotting_tools.data) == 2
assert (
plotting_tools.data[0]["Time"] == "1"
and plotting_tools.data[0]["Value1"] == "10"
)


def test_plot_data():
def test_plot_data(plotting_tools):
# Test successful plot generation
data_success = [
{"Time": "1", "Value1": "10", "Value2": "20"},
Expand All @@ -112,7 +121,10 @@ def test_plot_data():
), patch(
"matplotlib.pyplot.close"
):
created_plots = plot_data(data_success, headers, matched_headers)
plotting_tools.data = data_success
plotting_tools.headers = headers
plotting_tools.matched_headers = matched_headers
created_plots = plotting_tools.plot_data()
assert "time_vs_value1.png" in created_plots
assert "time_vs_value2.png" in created_plots

Expand All @@ -122,8 +134,12 @@ def test_plot_data():
{"Time": "2", "Value1": "C", "Value2": "D"},
]

plotting_tools.data = data_failure
plotting_tools.headers = headers
plotting_tools.matched_headers = matched_headers

with pytest.raises(Exception) as excinfo:
plot_data(data_failure, headers, matched_headers)
plotting_tools.plot_data()
assert "All plots failed due to non-numeric data." in str(excinfo.value)


Expand Down

0 comments on commit 7b613ad

Please sign in to comment.