-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
54 lines (42 loc) · 1.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Step 1: Example corpus -- six short statements on climate change.
# A sentence's position in this list is the index used to refer to it below.
sentences = [
    "Climate change is causing more frequent extreme weather events.",
    "The polar ice caps are melting at an alarming rate.",
    "Rising sea levels threaten coastal communities worldwide.",
    "Deforestation contributes significantly to climate change.",
    "Renewable energy sources can help mitigate climate change.",
    "Global warming is a key driver of climate change.",
]
# Step 2: Generate sentence embeddings
# Encode every sentence with the pretrained 'paraphrase-MiniLM-L6-v2' model.
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
embeddings = model.encode(sentences)

# Step 3: Create a similarity graph
# similarity_matrix[i][j] holds the cosine similarity of sentences i and j.
similarity_matrix = cosine_similarity(embeddings)
threshold = 0.1  # Pairs scoring at or below this value get no edge

G = nx.Graph()

# One node per sentence, keyed by its list index; the text itself is kept
# as the node attribute "sentence".
for i, sentence in enumerate(sentences):
    G.add_node(i, sentence=sentence)

# Visit each unordered pair once (i < j) and connect it when its similarity
# is strictly above the threshold.
num_sentences = len(sentences)
for i in range(num_sentences):
    for j in range(i + 1, num_sentences):
        similarity = similarity_matrix[i][j]
        if similarity > threshold:
            # Store a plain Python float rather than a NumPy scalar so the
            # edge attribute stays friendly to generic serializers.
            G.add_edge(i, j, weight=float(similarity))
# Step 4: Visualize the graph
plt.figure(figsize=(10, 8))

# Spring layout with a fixed seed so node positions are the same on every run
pos = nx.spring_layout(G, seed=42)

# Draw nodes and edges, labelling each node "Sentence N" (1-based) instead of
# printing the full sentence text inside the node.
nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=3000)
nx.draw_networkx_edges(G, pos, width=2.0, alpha=0.6)
node_labels = {i: f"Sentence {i+1}" for i in G.nodes}
nx.draw_networkx_labels(G, pos, node_labels, font_size=10)

# Display edge labels (similarity scores formatted to two decimals)
edge_labels = {(i, j): f"{similarity_matrix[i][j]:.2f}" for i, j in G.edges}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_color='red')

# The title must be set BEFORE savefig; setting it afterwards (as the code
# previously did) leaves the saved PNG without a title.
plt.title("Sentence Similarity Graph", fontsize=16)
plt.savefig("sentence_similarity_graph.png", dpi=300, bbox_inches='tight')
plt.show()