2113 | 2113 | "# ######## Save to csv\n", |
2114 | 2114 | "# lme_dataset_df_filtered_human_labelling.to_csv(\"lme_dataset_df_filtered_human_labelling.csv\", index=False)" |
2115 | 2115 | ] |
2116 | | - }, |
2117 | | - { |
2118 | | - "cell_type": "markdown", |
2119 | | - "metadata": {}, |
2120 | | - "source": [ |
2121 | | - "# Archive" |
2122 | | - ] |
2123 | | - }, |
2124 | | - { |
2125 | | - "cell_type": "code", |
2126 | | - "execution_count": null, |
2127 | | - "metadata": {}, |
2128 | | - "outputs": [], |
2129 | | - "source": [ |
2130 | | - "# Deserialize the JSON strings back into lists of dictionaries\n", |
2131 | | - "list1_dicts = json.loads(eval_minidataset_labelled.iloc[0][output_column_name])\n", |
2132 | | - "list2_dicts = json.loads(eval_minidataset_labelled.iloc[1][output_column_name])\n", |
2133 | | - "\n", |
2134 | | - "list3_dicts = json.loads(eval_minidataset_labelled.iloc[6][output_column_name])\n", |
2135 | | - "list4_dicts = json.loads(eval_minidataset_labelled.iloc[7][output_column_name])\n", |
2136 | | - "list5_dicts = json.loads(eval_minidataset_labelled.iloc[12][output_column_name])\n", |
2137 | | - "list6_dicts = json.loads(eval_minidataset_labelled.iloc[13][output_column_name])\n", |
2138 | | - "\n", |
2139 | | - "# Convert the dictionaries back into EntityNode objects\n", |
2140 | | - "list1_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list1_dicts]\n", |
2141 | | - "list2_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list2_dicts]\n", |
2142 | | - "list3_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list3_dicts]\n", |
2143 | | - "list4_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list4_dicts]\n", |
2144 | | - "list5_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list5_dicts]\n", |
2145 | | - "list6_nodes = [dict_to_entity(node_dict, EntityNode) for node_dict in list6_dicts]\n", |
2146 | | - "\n", |
2147 | | - "# Now list1_nodes, list2_nodes, list3_nodes, and list4_nodes contain the deserialized EntityNode objects\n", |
2148 | | - "\n", |
2149 | | - "#Convert Each List Into A List Of Only The Node Names\n", |
2150 | | - "list1_node_names = [node.name for node in list1_nodes]\n", |
2151 | | - "list2_node_names = [node.name for node in list2_nodes]\n", |
2152 | | - "list3_node_names = [node.name for node in list3_nodes]\n", |
2153 | | - "list4_node_names = [node.name for node in list4_nodes]\n", |
2154 | | - "list5_node_names = [node.name for node in list5_nodes]\n", |
2155 | | - "list6_node_names = [node.name for node in list6_nodes]\n", |
2156 | | - "\n", |
2157 | | - "\n", |
2158 | | - "#Print Out The Lists\n", |
2159 | | - "print(list1_node_names)\n", |
2160 | | - "print(list2_node_names)\n", |
2161 | | - "print(list3_node_names)\n", |
2162 | | - "print(list4_node_names)\n", |
2163 | | - "print(list5_node_names)\n", |
2164 | | - "print(list6_node_names)\n", |
2165 | | - "\n", |
2166 | | - "#NowCollectAndPrintTheNoteSummaries For Each List\n", |
2167 | | - "\n", |
2168 | | - "# Function to print node summaries\n", |
2169 | | - "def print_node_summaries(node_list, list_name):\n", |
2170 | | - " print(f\"Node summaries for {list_name}:\")\n", |
2171 | | - " for node in node_list:\n", |
2172 | | - " # Assuming each node has a 'name' and 'attributes' attribute\n", |
2173 | | - " print(f\"Name: {node.name}, Summary: {node.summary}\")\n", |
2174 | | - " print(\"\\n\")\n", |
2175 | | - "\n", |
2176 | | - "# Print node summaries for each list\n", |
2177 | | - "print_node_summaries(list1_nodes, \"list1\")\n", |
2178 | | - "print_node_summaries(list2_nodes, \"list2\")\n", |
2179 | | - "print_node_summaries(list3_nodes, \"list3\")\n", |
2180 | | - "print_node_summaries(list4_nodes, \"list4\")\n", |
2181 | | - "print_node_summaries(list5_nodes, \"list5\")\n", |
2182 | | - "print_node_summaries(list6_nodes, \"list6\")\n", |
2183 | | - "\n" |
2184 | | - ] |
2185 | | - }, |
2186 | | - { |
2187 | | - "cell_type": "code", |
2188 | | - "execution_count": 11, |
2189 | | - "metadata": {}, |
2190 | | - "outputs": [], |
2191 | | - "source": [ |
2192 | | - "# def get_output(row, model):\n", |
2193 | | - "# \"\"\"\n", |
2194 | | - "# Gets the output for a given model\n", |
2195 | | - "# \"\"\"\n", |
2196 | | - "# if row['task_name'] == \"extract_nodes\":\n", |
2197 | | - "# return row['input_extracted_nodes']\n", |
2198 | | - "# elif row['task_name'] == \"dedupe_nodes\":\n", |
2199 | | - "# return row['input_existing_relevant_nodes']\n", |
2200 | | - "# elif row['task_name'] == \"extract_edges\":\n", |
2201 | | - "# return row['input_extracted_edges']\n", |
2202 | | - "# elif row['task_name'] == \"dedupe_edges\":\n", |
2203 | | - "# return row['input_existing_relevant_edges']\n", |
2204 | | - "# elif row['task_name'] == \"extract_edge_dates\":\n", |
2205 | | - "# return row['input_extracted_edge_dates']\n", |
2206 | | - "# elif row['task_name'] == \"edge_invalidation\":\n", |
2207 | | - "# return row['input_edge_invalidation']\n", |
2208 | | - "\n", |
2209 | | - "# def insert_gpt4o_answers(df):\n", |
2210 | | - "# \"\"\"\n", |
2211 | | - "# Inserts gpt4o answers by doing ingestion in the right order and filling extra input columns as needed\n", |
2212 | | - "# \"\"\"\n", |
2213 | | - "# for _, row in df.iterrows():\n", |
2214 | | - "# # Get the output\n", |
2215 | | - "# output_gpt4o = get_output(row, \"gpt4o\")" |
2216 | | - ] |
2217 | | - }, |
2218 | | - { |
2219 | | - "cell_type": "code", |
2220 | | - "execution_count": 12, |
2221 | | - "metadata": {}, |
2222 | | - "outputs": [], |
2223 | | - "source": [ |
2224 | | - "# ######## Write a multisession to a txt file\n", |
2225 | | - "# question_id = \"gpt4_93159ced_abs\"\n", |
2226 | | - "# multi_session = lme_dataset_df[lme_dataset_df['question_id'] == question_id][\"haystack_sessions\"].iloc[0]\n", |
2227 | | - "\n", |
2228 | | - "\n", |
2229 | | - "# with open(f'{question_id}.txt', 'w') as f:\n", |
2230 | | - "\n", |
2231 | | - "\n", |
2232 | | - "# for session in multi_session:\n", |
2233 | | - "# f.write(\"New session \" + \"*\"*200 + \"\\n\")\n", |
2234 | | - "# for message in session:\n", |
2235 | | - " \n", |
2236 | | - "# f.write(f\"{message}\\n\")\n", |
2237 | | - "# f.write(\"\\n\")" |
2238 | | - ] |
2239 | | - }, |
2240 | | - { |
2241 | | - "cell_type": "code", |
2242 | | - "execution_count": 13, |
2243 | | - "metadata": {}, |
2244 | | - "outputs": [], |
2245 | | - "source": [ |
2246 | | - "# ######## Method to save all of the snippets (or only firsts/lasts) of the specified multi-sessions to a CSV file\n", |
2247 | | - "\n", |
2248 | | - "\n", |
2249 | | - "# def make_messages_readable(messages):\n", |
2250 | | - "# if len(messages) == 0:\n", |
2251 | | - "# return []\n", |
2252 | | - "# if messages == [None]:\n", |
2253 | | - "# return None\n", |
2254 | | - "# result_string = \"\"\n", |
2255 | | - "# for message in messages:\n", |
2256 | | - "# if message is None:\n", |
2257 | | - "# continue\n", |
2258 | | - "# result_string += \"|\"*80 + f\" {message['role']} \" + \"|\"*80 + \"\\n\\n\" + f\"{message['content']}\\n\\n\"\n", |
2259 | | - "# return result_string\n", |
2260 | | - "\n", |
2261 | | - "\n", |
2262 | | - "\n", |
2263 | | - "# def handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages):\n", |
2264 | | - "# if readableMessages:\n", |
2265 | | - "# snippet['message'] = make_messages_readable([snippet['message']])\n", |
2266 | | - "# if spreadOutPreviousMessages:\n", |
2267 | | - "# for i in range(max_num_previous_messages):\n", |
2268 | | - "# snippet[f\"previous_message_{i+1}\"] = make_messages_readable([snippet[f\"previous_message_{i+1}\"]])\n", |
2269 | | - "# return snippet\n", |
2270 | | - "# snippet['previous_messages'] = make_messages_readable(snippet['previous_messages'])\n", |
2271 | | - "# return snippet\n", |
2272 | | - " \n", |
2273 | | - "# snippet['message'] = json.dumps(snippet['message'])\n", |
2274 | | - "# if spreadOutPreviousMessages:\n", |
2275 | | - "# for i in range(max_num_previous_messages):\n", |
2276 | | - "# snippet[f\"previous_message_{i+1}\"] = json.dumps(snippet[f\"previous_message_{i+1}\"])\n", |
2277 | | - "# return snippet\n", |
2278 | | - "# snippet['previous_messages'] = json.dumps(snippet['previous_messages'])\n", |
2279 | | - "# return snippet\n", |
2280 | | - "\n", |
2281 | | - "# def save_multi_session_snippets_to_csv(question_ids, max_num_previous_messages=5, mode=\"all\", readableMessages=False, spreadOutPreviousMessages=False, snippet_index=None):\n", |
2282 | | - "# \"\"\"\n", |
2283 | | - "# Creates a csv where each row is a \"snippet\" from longmemeval. A snippet is a message and set of previous messages.\n", |
2284 | | - "\n", |
2285 | | - "# mode:\n", |
2286 | | - "# mode=\"all\" --> saves all possible snippets for the specified question_ids. \n", |
2287 | | - "# mode=\"firsts_only\" --> saves only the first snippet for each question_id.\n", |
2288 | | - "# mode=\"lasts_only\" --> saves only the last snippet for each question_id.\n", |
2289 | | - "# mode=\"index_only\" --> saves only the snippet at snippet_index for each question_id.\n", |
2290 | | - "\n", |
2291 | | - "# readableMessages:\n", |
2292 | | - "# readableMessages=True --> saves the messages in a readable format, intended for reading in Google sheets.\n", |
2293 | | - "# readableMessages=False --> saves the messages in a json format.\n", |
2294 | | - "\n", |
2295 | | - "# spreadOutPreviousMessages:\n", |
2296 | | - "# spreadOutPreviousMessages=True --> spreads out the previous messages across multiple columns.\n", |
2297 | | - "# spreadOutPreviousMessages=False --> saves all previous messages in a single column.\n", |
2298 | | - " \n", |
2299 | | - "# \"\"\"\n", |
2300 | | - "\n", |
2301 | | - "# all_snippets = []\n", |
2302 | | - "# indices = []\n", |
2303 | | - "# for question_id in question_ids:\n", |
2304 | | - " \n", |
2305 | | - "# ######## First, lets combine the sessions and dates into a list of dicts\n", |
2306 | | - "# row = lme_dataset_df[lme_dataset_df['question_id'] == question_id]\n", |
2307 | | - "\n", |
2308 | | - "# if row.empty:\n", |
2309 | | - "# raise ValueError(f\"No question found with ID: {question_id}\")\n", |
2310 | | - "\n", |
2311 | | - "\n", |
2312 | | - "# # Extract the haystack_sessions column value\n", |
2313 | | - "# sessions = row['haystack_sessions'].iloc[0]\n", |
2314 | | - "\n", |
2315 | | - "# # Get the haystack_dates column value\n", |
2316 | | - "# session_dates = row['haystack_dates'].iloc[0]\n", |
2317 | | - "\n", |
2318 | | - "# # Combine into list of dictionaries\n", |
2319 | | - "# session_and_dates = [\n", |
2320 | | - "# {\n", |
2321 | | - "# \"session\": session,\n", |
2322 | | - "# \"date\": datetime.strptime(date, \"%Y/%m/%d (%a) %H:%M\")\n", |
2323 | | - "# } \n", |
2324 | | - "# for session, date in zip(sessions, session_dates)\n", |
2325 | | - "# ]\n", |
2326 | | - "\n", |
2327 | | - "# # Sort by date from earliest to latest\n", |
2328 | | - "# session_and_dates.sort(key=lambda x: x[\"date\"])\n", |
2329 | | - "\n", |
2330 | | - "\n", |
2331 | | - "# all_snippets_this_session = []\n", |
2332 | | - "\n", |
2333 | | - "# total_num_messages = sum([len(session_and_date[\"session\"]) for session_and_date in session_and_dates])\n", |
2334 | | - "\n", |
2335 | | - "# message_index_across_sessions = 0\n", |
2336 | | - "# for session_index, session_and_date in enumerate(session_and_dates):\n", |
2337 | | - "# for message_index_within_session, message in enumerate(session_and_date[\"session\"]):\n", |
2338 | | - " \n", |
2339 | | - "# num_previous_messages = min(max_num_previous_messages, message_index_across_sessions)\n", |
2340 | | - "# previous_snippets = all_snippets_this_session[message_index_across_sessions-num_previous_messages:]\n", |
2341 | | - "# previous_messages_only = [previous_snippet[\"message\"] for previous_snippet in previous_snippets]\n", |
2342 | | - "\n", |
2343 | | - "# snippet = {\n", |
2344 | | - "# \"question_id\": question_id,\n", |
2345 | | - "# \"multisession_index\": int(row.index[0]),\n", |
2346 | | - "# \"session_index\": session_index,\n", |
2347 | | - "# \"message_index_within_session\": message_index_within_session,\n", |
2348 | | - "# \"session_date\": session_and_date[\"date\"],\n", |
2349 | | - "# }\n", |
2350 | | - "\n", |
2351 | | - "# if spreadOutPreviousMessages:\n", |
2352 | | - "# previous_messages_only_padded = previous_messages_only + [None] * (max_num_previous_messages - len(previous_messages_only))\n", |
2353 | | - "# for i, prev_msg in enumerate(previous_messages_only_padded):\n", |
2354 | | - "# snippet[f\"previous_message_{i+1}\"] = prev_msg\n", |
2355 | | - "# snippet[\"message\"] = message\n", |
2356 | | - "# else:\n", |
2357 | | - "# snippet[\"message\"] = message\n", |
2358 | | - "# snippet[\"previous_messages\"] = previous_messages_only\n", |
2359 | | - "\n", |
2360 | | - "# all_snippets_this_session.append(snippet)\n", |
2361 | | - "# message_index_across_sessions += 1\n", |
2362 | | - "\n", |
2363 | | - "# if mode == \"firsts_only\":\n", |
2364 | | - "# all_snippets_this_session = [all_snippets_this_session[0]]\n", |
2365 | | - "# elif mode == \"lasts_only\":\n", |
2366 | | - "# all_snippets_this_session = [all_snippets_this_session[-1]]\n", |
2367 | | - "# elif mode == \"index_only\":\n", |
2368 | | - "# all_snippets_this_session = [all_snippets_this_session[snippet_index]]\n", |
2369 | | - "\n", |
2370 | | - "# all_snippets.extend(all_snippets_this_session)\n", |
2371 | | - "# indices.append(int(row.index[0]))\n", |
2372 | | - "\n", |
2373 | | - "\n", |
2374 | | - "# filename = \"lme_samples_indices=\"\n", |
2375 | | - "# indices.sort() \n", |
2376 | | - "# num_indices = len(indices)\n", |
2377 | | - "# if num_indices < 4:\n", |
2378 | | - "# for index in indices:\n", |
2379 | | - "# filename += f\"_{index}\"\n", |
2380 | | - "# else:\n", |
2381 | | - "# filename += f\"_{indices[0]}_etc_{indices[-1]}\"\n", |
2382 | | - "\n", |
2383 | | - "# if mode == \"firsts_only\":\n", |
2384 | | - "# filename += \"_firsts_only\"\n", |
2385 | | - "# elif mode == \"lasts_only\":\n", |
2386 | | - "# filename += \"_lasts_only\"\n", |
2387 | | - "# elif mode == \"index_only\":\n", |
2388 | | - "# filename += f\"_index_only={snippet_index}\"\n", |
2389 | | - "# if spreadOutPreviousMessages:\n", |
2390 | | - "# filename += \"_spreadOutPreviousMessages\"\n", |
2391 | | - "# if readableMessages:\n", |
2392 | | - "# filename += \"_readable\"\n", |
2393 | | - "# filename += \".csv\"\n", |
2394 | | - " \n", |
2395 | | - " \n", |
2396 | | - "# with open(filename, \"w\", newline=\"\") as csvfile:\n", |
2397 | | - "# writer = csv.DictWriter(csvfile, fieldnames=all_snippets[0].keys())\n", |
2398 | | - "# writer.writeheader()\n", |
2399 | | - "# for snippet in all_snippets:\n", |
2400 | | - "# processed_snippet = handle_message_readability(snippet, readableMessages, spreadOutPreviousMessages, max_num_previous_messages)\n", |
2401 | | - "# writer.writerow(processed_snippet)" |
2402 | | - ] |
2403 | 2116 | } |
2404 | 2117 | ], |
2405 | 2118 | "metadata": { |
|