You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# import Rcpp functions for pair matching
source("../misc/Stephane_matching.R")
# Work on a copy of the input data and derive a logical treatment flag from W.
data_match <- data
data_match$is_treated <- as.logical(data_match$W)
# Pair identifier, filled in once the matching has been computed.
data_match$pair_nb <- NA

# Optional weights for each covariate when computing the distances.
# WARNING: the order of the items in scaling needs to be the same as the
# order of the covariates (i.e. columns).
scaling <- setNames(rep(list(1), ncol(data_match)), colnames(data_match))

# Thresholds for each covariate; the default is Inf (i.e. no matching
# constraint on that covariate).
thresholds <- setNames(rep(list(Inf), ncol(data_match)), colnames(data_match))

# Covariates that are actually constrained.
thresholds[["sex"]] <- 0
thresholds[["age_cat"]] <- 1
thresholds[["bmi_corrected"]] <- 4

# Columns with a finite threshold, plus the treatment flag.
relevant_fields <- c(
  colnames(data_match)[unlist(thresholds) < Inf],
  "is_treated"
)
# Accumulators: the matched rows collected so far and the running pair counter.
matched_df <- data.frame()
total_nb_match <- 0

# Timing and progress reporting; the bar is scaled to the number of rows.
count <- 0
start_time <- Sys.time()
pb <- txtProgressBar(
  min = 0,
  max = nrow(data_match),
  initial = 0,
  char = "=",
  style = 3
)
## | | | 0%
N <- nrow(data_match)
#--------- explore treated units ---------#
# Split the data on the treatment flag. Rows whose flag is NA end up in
# neither group.
treated_units <- data_match[data_match$is_treated %in% TRUE, , drop = FALSE]
control_units <- data_match[data_match$is_treated %in% FALSE, , drop = FALSE]

# Group sizes (an early-exit guard for empty groups was left disabled here).
N_treated <- nrow(treated_units)
N_control <- nrow(control_units)

cat("Number of treated units:", N_treated, "\nNumber of control units:", N_control, "\n")
## Number of treated units: 11855
## Number of control units: 234
#------------------------------------------------------------------------------#
# Compute the discrepancies between treated and control units.
# NOTE(review): discrepancyMatrix() is defined outside this chunk; presumably
# it returns an N_treated x N_control matrix holding Inf wherever a threshold
# is violated -- confirm against its definition.
discrepancies <- discrepancyMatrix(treated_units, control_units, thresholds, scaling)

# Optional diagnostic (disabled):
# N_possible_matches <- sum(rowSums(discrepancies < Inf) > 0)
# cat("Number of prospective matched treated units =", N_possible_matches, "\n")

#-------- Force pair matching via bipartite maximal weighted matching --------#
# A finite discrepancy marks an admissible (row, col) pair; each admissible
# pair becomes one edge of the bipartite graph.
adj <- discrepancies < Inf
edges_mat <- which(adj, arr.ind = TRUE)

# Edge weights decrease with the discrepancy. The two-column index matrix picks
# one entry per edge and yields numeric(0) when there are no admissible pairs.
# This must happen BEFORE the column indices are shifted below.
weights <- 1 / (1 + discrepancies[edges_mat])

# Vertices 1..N_treated are the treated units; control units are numbered
# after them, so shift the column indices accordingly.
edges_mat[, "col"] <- edges_mat[, "col"] + N_treated
# Flatten row by row into (from, to, from, to, ...).
edges_vector <- c(t(edges_mat))

#------------------------------------------------------------------------------#
# Build the graph from the list of edges and compute the matching.
BG <- make_bipartite_graph(
  c(rep(TRUE, N_treated), rep(FALSE, N_control)),
  edges = edges_vector
)
MBM <- max_bipartite_match(BG, weights = weights)
# List the matched pairs as c(treated index, control index).
# MBM$matching is indexed by graph vertex: the first N_treated entries belong
# to the treated units and hold the matched vertex id, or NA when unmatched.
# seq_len() keeps this safe when there are no treated units at all.
matched_treated <- which(!is.na(MBM$matching[seq_len(N_treated)]))

# Control vertices were numbered after the treated ones, so subtracting
# N_treated recovers the row index within control_units.
pairs_list <- lapply(
  matched_treated,
  function(i) c(i, MBM$matching[i] - N_treated)
)
N_matched <- length(pairs_list)
# Record the matched pairs: give both members of each pair the same pair
# number and append them to matched_df, treated unit first, then its control.
# The guard avoids indexing an empty pairs_list when nothing was matched.
if (N_matched > 0) {
  matched_pairs <- pairs_list[seq_len(N_matched)]
  treated_idx <- vapply(matched_pairs, function(p) p[1], numeric(1))
  control_idx <- vapply(matched_pairs, function(p) p[2], numeric(1))

  # Pair numbers continue from the running total.
  pair_ids <- total_nb_match + seq_len(N_matched)
  treated_units[treated_idx, "pair_nb"] <- pair_ids
  control_units[control_idx, "pair_nb"] <- pair_ids

  # Stack all treated rows above all control rows, then interleave them so the
  # rows read treated_1, control_1, treated_2, control_2, ...; binding once
  # avoids growing matched_df row by row.
  pair_rows <- rbind(
    treated_units[treated_idx, ],
    control_units[control_idx, ]
  )
  interleaved <- as.vector(
    rbind(seq_len(N_matched), N_matched + seq_len(N_matched))
  )
  matched_df <- rbind(matched_df, pair_rows[interleaved, ])

  total_nb_match <- total_nb_match + N_matched

  # Debugging aid (disabled): print the relevant fields of a given pair, e.g.
  # print(treated_units[pairs_list[[i]][1], relevant_fields])
  # print(control_units[pairs_list[[i]][2], relevant_fields])
}
# Advance the progress bar by one step and report the elapsed wall-clock time
# since start_time.
count <- count + 1
setTxtProgressBar(pb, value = count)
print(difftime(Sys.time(), start_time))