-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmost_updatedscript.py
193 lines (145 loc) · 7.49 KB
/
most_updatedscript.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#THIS IS THE ONE BEING EDITED AND WORKED ON
# Might also be worth checking with Michael to be sure GAMLSS is implemented correctly! I know it was quite complex.
from rpy2.robjects import r, pandas2ri
from rpy2.robjects.packages import importr
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr
from rpy2.robjects.conversion import localconverter
from rpy2.robjects.conversion import localconverter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem
from scipy import stats
# --- R bridge and data loading ------------------------------------------------
# Turn on rpy2's global pandas <-> R conversion for the rest of the script.
pandas2ri.activate()
# Silence R warnings for the whole R session (a negative `warn` makes R ignore them).
# NOTE(review): this also hides GAMLSS convergence warnings -- confirm that is intended.
r('options(warn=-1)')
# Load the R packages. `importr` attaches gamlss in the embedded R session; the
# model itself is later fitted through r['gamlss'], not through this Python handle.
gamlss = importr('gamlss')
r('library(mgcv)') # Load the mgcv package here
# Load the combined, QA-filtered dataset from a hard-coded absolute path.
# NOTE(review): the path is machine-specific; the script fails elsewhere unless it is edited.
data_file = "/home/local/VANDERBILT/samirj/3rdyearproj/lgdata/COMBINED_FILTERED_QA.csv"
data = pd.read_csv(data_file, low_memory=False)
# Row count before de-duplication, as recorded by the author: 37172.
#print(f"Number of data points before filtering: {len(data)}") #this took so long: 37172
# Keep only the first row for every (scan, age) pair; the result is bound back to
# `data` so the rest of the script keeps working on the de-duplicated frame.
data = data.drop_duplicates(subset=['scan', 'age'], keep='first')
# Row count after de-duplication, as recorded by the author: 37013 (159 rows removed).
#print(f"Number of data points after filtering: {len(data)}") #this took so long: 37013
#159 points removed
# Print the de-duplicated DataFrame for a quick visual check (optional).
print(data)
# Define a function to process the data and fit a GAMLSS model
def process_data_and_fit_gamlss(data, tract, feature):
    """Add a left/right asymmetry column and fit a GAMLSS model of it on age.

    The asymmetry index is (left - right) / (left + right), computed from the
    columns ``{tract}_left-{feature}`` and ``{tract}_right-{feature}`` and
    stored as ``{tract}_asymmetry_{feature}`` on a copy of ``data``.

    Args:
        data: pandas DataFrame holding the left/right feature columns and an
            ``age`` column.
        tract: Tract name used to build the column names.
        feature: Feature name used to build the column names.

    Returns:
        A tuple ``(gamlss_model, asym_col_name, data)``: the fitted R model
        object, the name of the new asymmetry column, and the copied frame
        that contains it.
    """
    left_col = f"{tract}_left-{feature}"
    right_col = f"{tract}_right-{feature}"
    asym_col_name = f"{tract}_asymmetry_{feature}"

    # Work on a copy so the caller's frame is never modified.
    data = data.copy()
    left_values = data.loc[:, left_col]
    right_values = data.loc[:, right_col]
    data.loc[:, asym_col_name] = (left_values - right_values) / (left_values + right_values)

    # Hand the complete frame over to R as a data.frame.
    with localconverter(pandas2ri.converter):
        r_data = pandas2ri.py2rpy(data)

    # Asymmetry modelled as a penalized B-spline smooth of age, family NO.
    formula = r(f'{asym_col_name} ~ pb(age)')
    family = r('NO')
    fit_gamlss = r['gamlss']
    gamlss_model = fit_gamlss(formula=formula, family=family, data=r_data)

    return gamlss_model, asym_col_name, data
# Step 1: Remove outliers more than 3 standard deviations away from the mean
def remove_outliers(data, asym_col_name, n_std=3):
    """Drop rows whose asymmetry value lies too far from the column mean.

    A row is kept when ``|value - mean| <= n_std * std``, where mean and
    standard deviation are taken over ``data[asym_col_name]`` (pandas skips
    missing values when computing both). Rows with a missing asymmetry value
    are always dropped, because they cannot satisfy the comparison.

    Args:
        data: pandas DataFrame containing ``asym_col_name``.
        asym_col_name: Name of the column to screen for outliers.
        n_std: Width of the acceptance band in standard deviations.
            Defaults to 3.

    Returns:
        A filtered DataFrame with the original index labels of the kept rows.
    """
    values = data[asym_col_name]
    mean = values.mean()
    std_dev = values.std()

    if pd.isna(std_dev):
        # Fewer than two non-missing values: the spread is undefined, so no
        # value can be called an outlier. Comparing against NaN would discard
        # every row, so only rows with a missing value are dropped here.
        return data[values.notna()]

    return data[(values - mean).abs() <= n_std * std_dev]
# Function to extract GAMLSS fitted and predicted values
def get_gamlss_fitted_values(gamlss_model, ages):
    """Predict the fitted response of a GAMLSS model at the given ages.

    Args:
        gamlss_model: R model object as returned by the ``gamlss`` fit.
        ages: Iterable of numeric ages at which to evaluate the model.

    Returns:
        A NumPy array built from the value R's ``predict`` returns for
        ``type='response'``.
    """
    # R data.frame with a single `age` column, matching the model's predictor.
    age_frame = robjects.DataFrame({'age': robjects.FloatVector(ages)})
    r_predict = r['predict']
    return np.array(r_predict(gamlss_model, newdata=age_frame, type='response'))
# --- Interactive selection of what to analyse ---------------------------------
# Ask the user for the tract and the feature; both are used to build column names.
tract = input("Enter the tract name (e.g., AF, UF, SLF): ").strip()
feature = input("Enter the feature to analyze (e.g., volume, FA): ").strip()
# Keep only the identifier columns, age, sex and the left/right columns of the
# chosen tract/feature, then drop rows with any missing value and exact duplicates.
# A tract/feature pair without matching columns raises a KeyError on the selection.
columns_to_keep = ['scan', 'subject', 'session', 'dataset', 'age', 'sex', f'{tract}_left-{feature}', f'{tract}_right-{feature}']
data = data[columns_to_keep]
data = data.dropna()
data = data.drop_duplicates()
# Sorted unique ages across both sexes; used as the prediction grid and x-axis.
unique_ages = data['age'].unique()
unique_ages.sort()
# Split the data by sex: rows with sex == 1 go to the male frame, sex == 0 to the
# female frame. NOTE(review): confirm this coding against the dataset documentation.
male_data = data[data['sex'] == 1]
female_data = data[data['sex'] == 0]
# Fit one GAMLSS model per sex; the returned frames carry the new asymmetry column.
male_model, male_asym_col_name, male_data = process_data_and_fit_gamlss(male_data, tract, feature)
female_model, female_asym_col_name, female_data = process_data_and_fit_gamlss(female_data, tract, feature)
# Remove outliers from both frames.
# NOTE(review): this happens AFTER the models were fitted, so the fitted models
# still include the outliers; only the data used below (CI, scatter) is filtered.
# Confirm whether the models should be refitted on the filtered data.
male_data = remove_outliers(male_data, male_asym_col_name) #filtered both male and female data
female_data = remove_outliers(female_data, female_asym_col_name)
# Evaluate each fitted model on the shared age grid.
male_fitted_values = get_gamlss_fitted_values(male_model, unique_ages)
female_fitted_values = get_gamlss_fitted_values(female_model, unique_ages)
# Function to calculate the 95% Confidence Interval (CI) for asymmetry values
def calculate_confidence_interval(data, asym_col_name, confidence_level=0.95):
    """Compute a normal-approximation confidence interval for a column's mean.

    Missing values are dropped first so that the mean and the standard error
    are computed over the same sample. Without this, ``Series.mean`` skips
    NaN while ``scipy.stats.sem`` propagates it, and a single NaN makes the
    whole interval NaN.

    Args:
        data: pandas DataFrame containing ``asym_col_name``.
        asym_col_name: Name of the column to summarise.
        confidence_level: Coverage of the interval, between 0 and 1.
            Defaults to 0.95.

    Returns:
        A tuple ``(ci_lower, ci_upper, mean, sem_value)``. The bounds are NaN
        when fewer than two non-missing values are available, because the
        standard error is undefined in that case.
    """
    values = data[asym_col_name].dropna()

    mean = values.mean()
    sem_value = sem(values)

    # Interval of a normal distribution centred on the mean with the SEM as scale.
    ci_lower, ci_upper = stats.norm.interval(confidence_level, loc=mean, scale=sem_value)
    return ci_lower, ci_upper, mean, sem_value
# Function to remove outliers based on CI bounds
def remove_outliers_based_on_ci(data, asym_col_name, ci_lower, ci_upper):
    """Keep only the rows whose value lies inside ``[ci_lower, ci_upper]``.

    Both bounds are inclusive. Rows with a missing value in the column are
    dropped, since they fall inside no interval.

    Args:
        data: pandas DataFrame containing ``asym_col_name``.
        asym_col_name: Name of the column to test against the bounds.
        ci_lower: Lower bound of the accepted range.
        ci_upper: Upper bound of the accepted range.

    Returns:
        The filtered DataFrame.
    """
    inside_bounds = data[asym_col_name].between(ci_lower, ci_upper)
    return data[inside_bounds]
# Modify the 'process_data_and_fit_gamlss' function to include CI calculation and outlier removal
def process_data_and_fit_gamlss_with_ci(data, tract, feature, confidence_level=0.95):
    """Compute the asymmetry column and report a confidence interval for its mean.

    Despite its name, this function does not fit a GAMLSS model and does not
    remove any rows. It computes (left - right) / (left + right) on a copy of
    ``data``, derives the confidence interval of the mean via
    ``calculate_confidence_interval`` and prints a short summary.

    Args:
        data: pandas DataFrame holding ``{tract}_left-{feature}`` and
            ``{tract}_right-{feature}``.
        tract: Tract name used to build the column names.
        feature: Feature name used to build the column names.
        confidence_level: Coverage of the interval, between 0 and 1.
            Defaults to 0.95.

    Returns:
        A tuple ``(ci_lower, ci_upper, mean, sem_value)``.
    """
    left_col = f"{tract}_left-{feature}"
    right_col = f"{tract}_right-{feature}"
    asym_col_name = f"{tract}_asymmetry_{feature}"

    # Work on a copy so the caller's frame is never modified.
    data = data.copy()
    data.loc[:, asym_col_name] = (
        data.loc[:, left_col] - data.loc[:, right_col]
    ) / (
        data.loc[:, left_col] + data.loc[:, right_col]
    )

    ci_lower, ci_upper, mean, sem_value = calculate_confidence_interval(
        data, asym_col_name, confidence_level
    )

    # The label follows the requested level instead of always claiming 95%;
    # the default of 0.95 still prints "95% CI".
    print(f"{asym_col_name} {confidence_level * 100:g}% CI: ({ci_lower:.2f}, {ci_upper:.2f})")
    print(f"Mean: {mean:.2f}, SEM: {sem_value:.2f}")

    return ci_lower, ci_upper, mean, sem_value
# --- Confidence intervals -----------------------------------------------------
# Confidence interval of the mean asymmetry per sex, computed on the frames
# from which outliers were already removed. Each bound is a single number.
male_ci_lower, male_ci_upper, male_mean, male_sem = process_data_and_fit_gamlss_with_ci(male_data, tract, feature)
female_ci_lower, female_ci_upper, female_mean, female_sem = process_data_and_fit_gamlss_with_ci(female_data, tract, feature)
# Evaluate the fitted models on the age grid again.
# NOTE(review): no CI-based filtering is applied anywhere and the models are not
# refitted, so these calls reproduce the predictions already computed earlier.
male_fitted_values = get_gamlss_fitted_values(male_model, unique_ages)
female_fitted_values = get_gamlss_fitted_values(female_model, unique_ages)
# --- Plotting -----------------------------------------------------------------
# One figure with two side-by-side panels: male on the left, female on the right.
plt.figure(figsize=(14, 7))
# Male panel: raw points, fitted trend, and the CI band.
plt.subplot(1, 2, 1)
plt.scatter(male_data['age'], male_data[male_asym_col_name], label="Male Data", color="blue", alpha=0.1)
plt.plot(unique_ages, male_fitted_values, label="GAMLSS Fitted Trend", color="black")
# Shade the region between the two CI bounds. Because the bounds are scalars,
# the band is horizontal: it is the interval around the overall mean asymmetry,
# not an age-dependent band around the fitted trend.
plt.fill_between(unique_ages, male_ci_lower, male_ci_upper, color="blue", alpha=0.3, label="95% Confidence Interval")
plt.title(f"Male {feature} Asymmetry in the {tract} Tract")
plt.xlabel('Age')
plt.ylabel(f'{feature} Asymmetry ({tract})')
plt.legend()
# Female panel: same layout as the male panel.
plt.subplot(1, 2, 2)
plt.scatter(female_data['age'], female_data[female_asym_col_name], label="Female Data", color="red", alpha=0.1)
plt.plot(unique_ages, female_fitted_values, label="GAMLSS Fitted Trend", color="black")
# Horizontal band between the scalar CI bounds of the female mean asymmetry.
plt.fill_between(unique_ages, female_ci_lower, female_ci_upper, color="red", alpha=0.3, label="95% Confidence Interval")
plt.title(f"Female {feature} Asymmetry in the {tract} Tract")
plt.xlabel('Age')
plt.ylabel(f'{feature} Asymmetry ({tract})')
plt.legend()
# Tidy the spacing between the panels and display the figure.
plt.tight_layout()
plt.show()