|
2113 | 2113 | "# ######## Save to csv\n",
2114 | 2114 | "# lme_dataset_df_filtered_human_labelling.to_csv(\"lme_dataset_df_filtered_human_labelling.csv\", index=False)"
2115 | 2115 | ]
2116 |      | - },
2117 |      | - {
2118 |      | - "cell_type": "markdown",
2119 |      | - "metadata": {},
2120 |      | - "source": [
2121 |      | - "# Archive"
2122 |      | - ]
2123 |      | - },
2124 |      | - {
2125 |      | - "cell_type": "code",
2126 |      | - "execution_count": null,
2127 |      | - "metadata": {},
2128 |      | - "outputs": [],
2129 |      | - "source": [
2130 |      | - "# Deserialize the JSON strings back into lists of dictionaries\n",
2131 |      | - "list1_dicts = json.loads(eval_minidataset_labelled.iloc[0][output_column_name])\n",
2132 |      | - "list2_dicts = json.loads(eval_minidataset_labelled.iloc[1][output_column_name])\n",
2133 |      | - "\n",
2134 |      | - "list3_dicts = json.loads(eval_minidataset_labelled.iloc[6][output_column_name])\n",
2135 |      | - "list4_dicts = json.loads(eval_minidataset_labelled.iloc[7][output_column_name])\n",
2136 |      | - "list5_dicts = json.loads(eval_minidataset_labelled.iloc[12][output_column_name])\n",
2137 |      | - "list6_dicts = json.loads(eval_minidataset_labelled.iloc[13][output_column_name])\n",
2138 |      | - "\n",
2139 |      | - "# Convert the dictionaries back into EntityNode objects\n",
2140 |      | - "list1_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list1_dicts]\n",
2141 |      | - "list2_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list2_dicts]\n",
2142 |      | - "list3_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list3_dicts]\n",
2143 |      | - "list4_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list4_dicts]\n",
2144 |      | - "list5_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list5_dicts]\n",
2145 |      | - "list6_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list6_dicts]\n",
2146 |      | - "\n",
2147 |      | - "# Now list1_nodes, list2_nodes, list3_nodes, and list4_nodes contain the deserialized EntityNode objects\n",
2148 |      | - "\n",
2149 |      | - "#Convert Each List Into A List Of Only The Node Names\n",
2150 |      | - "list1_node_names = [node.name for node in list1_nodes]\n",
2151 |      | - "list2_node_names = [node.name for node in list2_nodes]\n",
2152 |      | - "list3_node_names = [node.name for node in list3_nodes]\n",
2153 |      | - "list4_node_names = [node.name for node in list4_nodes]\n",
2154 |      | - "list5_node_names = [node.name for node in list5_nodes]\n",
2155 |      | - "list6_node_names = [node.name for node in list6_nodes]\n",
2156 |      | - "\n",
2157 |      | - "\n",
2158 |      | - "#Print Out The Lists\n",
2159 |      | - "print(list1_node_names)\n",
2160 |      | - "print(list2_node_names)\n",
2161 |      | - "print(list3_node_names)\n",
2162 |      | - "print(list4_node_names)\n",
2163 |      | - "print(list5_node_names)\n",
2164 |      | - "print(list6_node_names)\n",
2165 |      | - "\n",
2166 |      | - "#NowCollectAndPrintTheNoteSummaries For Each List\n",
2167 |      | - "\n",
2168 |      | - "# Function to print node summaries\n",
2169 |      | - "def print_node_summaries(node_list, list_name):\n",
2170 |      | - "    print(f\"Node summaries for {list_name}:\")\n",
2171 |      | - "    for node in node_list:\n",
2172 |      | - "        # Assuming each node has a 'name' and 'attributes' attribute\n",
2173 |      | - "        print(f\"Name: {node.name}, Summary: {node.summary}\")\n",
2174 |      | - "    print(\"\\n\")\n",
2175 |      | - "\n",
2176 |      | - "# Print node summaries for each list\n",
2177 |      | - "print_node_summaries(list1_nodes, \"list1\")\n",
2178 |      | - "print_node_summaries(list2_nodes, \"list2\")\n",
2179 |      | - "print_node_summaries(list3_nodes, \"list3\")\n",
2180 |      | - "print_node_summaries(list4_nodes, \"list4\")\n",
2181 |      | - "print_node_summaries(list5_nodes, \"list5\")\n",
2182 |      | - "print_node_summaries(list6_nodes, \"list6\")\n",
2183 |      | - "\n"
2184 |      | - ]
2185 |      | - },
2186 |      | - {
2187 |      | - "cell_type": "code",
2188 |      | - "execution_count": 11,
2189 |      | - "metadata": {},
2190 |      | - "outputs": [],
2191 |      | - "source": [
2192 |      | - "# def get_output(row, model):\n",
2193 |      | - "#     \"\"\"\n",
2194 |      | - "#     Gets the output for a given model\n",
2195 |      | - "#     \"\"\"\n",
2196 |      | - "#     if row['task_name'] == \"extract_nodes\":\n",
2197 |      | - "#         return row['input_extracted_nodes']\n",
2198 |      | - "#     elif row['task_name'] == \"dedupe_nodes\":\n",
2199 |      | - "#         return row['input_existing_relevant_nodes']\n",
2200 |      | - "#     elif row['task_name'] == \"extract_edges\":\n",
2201 |      | - "#         return row['input_extracted_edges']\n",
2202 |      | - "#     elif row['task_name'] == \"dedupe_edges\":\n",
2203 |      | - "#         return row['input_existing_relevant_edges']\n",
2204 |      | - "#     elif row['task_name'] == \"extract_edge_dates\":\n",
2205 |      | - "#         return row['input_extracted_edge_dates']\n",
2206 |      | - "#     elif row['task_name'] == \"edge_invalidation\":\n",
2207 |      | - "#         return row['input_edge_invalidation']\n",
2208 |      | - "\n",
2209 |      | - "# def insert_gpt4o_answers(df):\n",
2210 |      | - "#     \"\"\"\n",
2211 |      | - "#     Inserts gpt4o answers by doing ingestion in the right order and filling extra input columns as needed\n",
2212 |      | - "#     \"\"\"\n",
2213 |      | - "#     for _, row in df.iterrows():\n",
2214 |      | - "#         # Get the output\n",
2215 |      | - "#         output_gpt4o = get_output(row, \"gpt4o\")"
2216 |      | - ]
2217 |      | - },
2218 |      | - {
2219 |      | - "cell_type": "code",
2220 |      | - "execution_count": 12,
2221 |      | - "metadata": {},
2222 |      | - "outputs": [],
2223 |      | - "source": [
2224 |      | - "# ######## Write a multisession to a txt file\n",
2225 |      | - "# question_id = \"gpt4_93159ced_abs\"\n",
2226 |      | - "# multi_session = lme_dataset_df[lme_dataset_df['question_id'] == question_id][\"haystack_sessions\"].iloc[0]\n",
2227 |      | - "\n",
2228 |      | - "\n",
2229 |      | - "# with open(f'{question_id}.txt', 'w') as f:\n",
2230 |      | - "\n",
2231 |      | - "\n",
2232 |      | - "#     for session in multi_session:\n",
2233 |      | - "#         f.write(\"New session \" + \"*\"*200 + \"\\n\")\n",
2234 |      | - "#         for message in session:\n",
2235 |      | - "    \n",
2236 |      | - "#             f.write(f\"{message}\\n\")\n",
2237 |      | - "#         f.write(\"\\n\")"
2238 |      | - ]
2239 |      | - },
2240 |      | - {
2241 |      | - "cell_type": "code",
2242 |      | - "execution_count": 13,
2243 |      | - "metadata": {},
2244 |      | - "outputs": [],
2245 |      | - "source": [
2246 |      | - "# ######## Method to save all of the snippets (or only firsts/lasts) of the specified multi-sessions to a CSV file\n",
2247 |      | - "\n",
2248 |      | - "\n",
2249 |      | - "# def make_messages_readable(messages):\n",
2250 |      | - "#     if len(messages) == 0:\n",
2251 |      | - "#         return []\n",
2252 |      | - "#     if messages == [None]:\n",
2253 |      | - "#         return None\n",
2254 |      | - "#     result_string = \"\"\n",
2255 |      | - "#     for message in messages:\n",
2256 |      | - "#         if message is None:\n",
2257 |      | - "#             continue\n",
2258 |      | - "#         result_string += \"|\"*80 + f\" {message['role']} \" + \"|\"*80 + \"\\n\\n\" + f\"{message['content']}\\n\\n\"\n",
2259 |      | - "#     return result_string\n",
2260 |      | - "\n",
2261 |      | - "\n",
2262 |      | - "\n",
2263 |      | - "# def handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages):\n",
2264 |      | - "#     if readableMessages:\n",
2265 |      | - "#         snippet['message'] = make_messages_readable([snippet['message']])\n",
2266 |      | - "#         if spreadOutPreviousMessages:\n",
2267 |      | - "#             for i in range(max_num_previous_messages):\n",
2268 |      | - "#                 snippet[f\"previous_message_{i+1}\"] = make_messages_readable([snippet[f\"previous_message_{i+1}\"]])\n",
2269 |      | - "#             return snippet\n",
2270 |      | - "#         snippet['previous_messages'] = make_messages_readable(snippet['previous_messages'])\n",
2271 |      | - "#         return snippet\n",
2272 |      | - "    \n",
2273 |      | - "#     snippet['message'] = json.dumps(snippet['message'])\n",
2274 |      | - "#     if spreadOutPreviousMessages:\n",
2275 |      | - "#         for i in range(max_num_previous_messages):\n",
2276 |      | - "#             snippet[f\"previous_message_{i+1}\"] = json.dumps(snippet[f\"previous_message_{i+1}\"])\n",
2277 |      | - "#         return snippet\n",
2278 |      | - "#     snippet['previous_messages'] = json.dumps(snippet['previous_messages'])\n",
2279 |      | - "#     return snippet\n",
2280 |      | - "\n",
2281 |      | - "# def save_multi_session_snippets_to_csv(question_ids, max_num_previous_messages=5, mode=\"all\", readableMessages=False, spreadOutPreviousMessages=False, snippet_index=None):\n",
2282 |      | - "#     \"\"\"\n",
2283 |      | - "#     Creates a csv where each row is a \"snippet\" from longmemeval. A snippet is a message and set of previous messages.\n",
2284 |      | - "\n",
2285 |      | - "#     mode:\n",
2286 |      | - "#         mode=\"all\" --> saves all possible snippets for the specified question_ids. \n",
2287 |      | - "#         mode=\"firsts_only\" --> saves only the first snippet for each question_id.\n",
2288 |      | - "#         mode=\"lasts_only\" --> saves only the last snippet for each question_id.\n",
2289 |      | - "#         mode=\"index_only\" --> saves only the snippet at snippet_index for each question_id.\n",
2290 |      | - "\n",
2291 |      | - "#     readableMessages:\n",
2292 |      | - "#         readableMessages=True --> saves the messages in a readable format, intended for reading in Google sheets.\n",
2293 |      | - "#         readableMessages=False --> saves the messages in a json format.\n",
2294 |      | - "\n",
2295 |      | - "#     spreadOutPreviousMessages:\n",
2296 |      | - "#         spreadOutPreviousMessages=True --> spreads out the previous messages across multiple columns.\n",
2297 |      | - "#         spreadOutPreviousMessages=False --> saves all previous messages in a single column.\n",
2298 |      | - "    \n",
2299 |      | - "#     \"\"\"\n",
2300 |      | - "\n",
2301 |      | - "#     all_snippets = []\n",
2302 |      | - "#     indices = []\n",
2303 |      | - "#     for question_id in question_ids:\n",
2304 |      | - "    \n",
2305 |      | - "#         ######## First, lets combine the sessions and dates into a list of dicts\n",
2306 |      | - "#         row = lme_dataset_df[lme_dataset_df['question_id'] == question_id]\n",
2307 |      | - "\n",
2308 |      | - "#         if row.empty:\n",
2309 |      | - "#             raise ValueError(f\"No question found with ID: {question_id}\")\n",
2310 |      | - "\n",
2311 |      | - "\n",
2312 |      | - "#         # Extract the haystack_sessions column value\n",
2313 |      | - "#         sessions = row['haystack_sessions'].iloc[0]\n",
2314 |      | - "\n",
2315 |      | - "#         # Get the haystack_dates column value\n",
2316 |      | - "#         session_dates = row['haystack_dates'].iloc[0]\n",
2317 |      | - "\n",
2318 |      | - "#         # Combine into list of dictionaries\n",
2319 |      | - "#         session_and_dates = [\n",
2320 |      | - "#             {\n",
2321 |      | - "#                 \"session\": session,\n",
2322 |      | - "#                 \"date\": datetime.strptime(date, \"%Y/%m/%d (%a) %H:%M\")\n",
2323 |      | - "#             } \n",
2324 |      | - "#             for session, date in zip(sessions, session_dates)\n",
2325 |      | - "#         ]\n",
2326 |      | - "\n",
2327 |      | - "#         # Sort by date from earliest to latest\n",
2328 |      | - "#         session_and_dates.sort(key=lambda x: x[\"date\"])\n",
2329 |      | - "\n",
2330 |      | - "\n",
2331 |      | - "#         all_snippets_this_session = []\n",
2332 |      | - "\n",
2333 |      | - "#         total_num_messages = sum([len(session_and_date[\"session\"]) for session_and_date in session_and_dates])\n",
2334 |      | - "\n",
2335 |      | - "#         message_index_across_sessions = 0\n",
2336 |      | - "#         for session_index, session_and_date in enumerate(session_and_dates):\n",
2337 |      | - "#             for message_index_within_session, message in enumerate(session_and_date[\"session\"]):\n",
2338 |      | - "    \n",
2339 |      | - "#                 num_previous_messages = min(max_num_previous_messages, message_index_across_sessions)\n",
2340 |      | - "#                 previous_snippets = all_snippets_this_session[message_index_across_sessions-num_previous_messages:]\n",
2341 |      | - "#                 previous_messages_only = [previous_snippet[\"message\"] for previous_snippet in previous_snippets]\n",
2342 |      | - "\n",
2343 |      | - "#                 snippet = {\n",
2344 |      | - "#                     \"question_id\": question_id,\n",
2345 |      | - "#                     \"multisession_index\": int(row.index[0]),\n",
2346 |      | - "#                     \"session_index\": session_index,\n",
2347 |      | - "#                     \"message_index_within_session\": message_index_within_session,\n",
2348 |      | - "#                     \"session_date\": session_and_date[\"date\"],\n",
2349 |      | - "#                 }\n",
2350 |      | - "\n",
2351 |      | - "#                 if spreadOutPreviousMessages:\n",
2352 |      | - "#                     previous_messages_only_padded = previous_messages_only + [None] * (max_num_previous_messages - len(previous_messages_only))\n",
2353 |      | - "#                     for i, prev_msg in enumerate(previous_messages_only_padded):\n",
2354 |      | - "#                         snippet[f\"previous_message_{i+1}\"] = prev_msg\n",
2355 |      | - "#                     snippet[\"message\"] = message\n",
2356 |      | - "#                 else:\n",
2357 |      | - "#                     snippet[\"message\"] = message\n",
2358 |      | - "#                     snippet[\"previous_messages\"] = previous_messages_only\n",
2359 |      | - "\n",
2360 |      | - "#                 all_snippets_this_session.append(snippet)\n",
2361 |      | - "#                 message_index_across_sessions += 1\n",
2362 |      | - "\n",
2363 |      | - "#         if mode == \"firsts_only\":\n",
2364 |      | - "#             all_snippets_this_session = [all_snippets_this_session[0]]\n",
2365 |      | - "#         elif mode == \"lasts_only\":\n",
2366 |      | - "#             all_snippets_this_session = [all_snippets_this_session[-1]]\n",
2367 |      | - "#         elif mode == \"index_only\":\n",
2368 |      | - "#             all_snippets_this_session = [all_snippets_this_session[snippet_index]]\n",
2369 |      | - "\n",
2370 |      | - "#         all_snippets.extend(all_snippets_this_session)\n",
2371 |      | - "#         indices.append(int(row.index[0]))\n",
2372 |      | - "\n",
2373 |      | - "\n",
2374 |      | - "#     filename = \"lme_samples_indices=\"\n",
2375 |      | - "#     indices.sort() \n",
2376 |      | - "#     num_indices = len(indices)\n",
2377 |      | - "#     if num_indices < 4:\n",
2378 |      | - "#         for index in indices:\n",
2379 |      | - "#             filename += f\"_{index}\"\n",
2380 |      | - "#     else:\n",
2381 |      | - "#         filename += f\"_{indices[0]}_etc_{indices[-1]}\"\n",
2382 |      | - "\n",
2383 |      | - "#     if mode == \"firsts_only\":\n",
2384 |      | - "#         filename += \"_firsts_only\"\n",
2385 |      | - "#     elif mode == \"lasts_only\":\n",
2386 |      | - "#         filename += \"_lasts_only\"\n",
2387 |      | - "#     elif mode == \"index_only\":\n",
2388 |      | - "#         filename += f\"_index_only={snippet_index}\"\n",
2389 |      | - "#     if spreadOutPreviousMessages:\n",
2390 |      | - "#         filename += \"_spreadOutPreviousMessages\"\n",
2391 |      | - "#     if readableMessages:\n",
2392 |      | - "#         filename += \"_readable\"\n",
2393 |      | - "#     filename += \".csv\"\n",
2394 |      | - "    \n",
2395 |      | - "    \n",
2396 |      | - "#     with open(filename, \"w\", newline=\"\") as csvfile:\n",
2397 |      | - "#         writer = csv.DictWriter(csvfile, fieldnames=all_snippets[0].keys())\n",
2398 |      | - "#         writer.writeheader()\n",
2399 |      | - "#         for snippet in all_snippets:\n",
2400 |      | - "#             processed_snippet = handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages)\n",
2401 |      | - "#             writer.writerow(processed_snippet)"
2402 |      | - ]
2403 | 2116 | }
2404 | 2117 | ],
2405 | 2118 | "metadata": {
0 commit comments