Commit f36e68e

Merge pull request #316 from kartikbhtt7/dev
Implemented BERTopic Model for topic segmentation
2 parents 5de7101 + 8f052c6 commit f36e68e

File tree

7 files changed: +166 −0 lines changed
src/topic_modelling/BERTopic/Dockerfile

Lines changed: 19 additions & 0 deletions

# Base image
FROM python:3.9

# Set the working directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt

# Copy all source code into the working directory
COPY . .

# Expose port for the server
EXPOSE 8000

# Command to run the server
CMD ["hypercorn", "--bind", "0.0.0.0:8000", "api:app"]
src/topic_modelling/BERTopic/README.md

Lines changed: 17 additions & 0 deletions

## BERTopic Topic Extraction Model

### Purpose
A model that extracts meaningful topic segments from a query dataset.

### Testing the model deployment
To test the model's topic head generation, follow these steps:

- Clone the repo.
- Go to the model folder, i.e. `cd src/topic_modelling/BERTopic`.
- Build the Docker image and test the API:

#### (IMP) The input .csv file must have exactly one text column, containing preprocessed text and named 'text' (a sample file is sketched below the commands)

```
docker build -t testmodel .
docker run -p 8000:8000 testmodel
curl -X POST -F "file=@test.csv" http://localhost:8000/embed -o output4.csv
```
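For reference, a minimal `test.csv` might look like the sketch below. The file name and rows are illustrative only (not part of this commit); any CSV with a preprocessed 'text' column works:

```
text
"water supply issue in the village"
"pension payment not received"
"request for new hand pump installation"
```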
src/topic_modelling/BERTopic/__init__.py

Lines changed: 2 additions & 0 deletions

from .request import *
from .model import *

src/topic_modelling/BERTopic/api.py

Lines changed: 55 additions & 0 deletions

import os
import io
import json
import pandas as pd
from quart import Quart, request, Response, send_file
from model import Model
from request import ModelRequest

app = Quart(__name__)

# Initialize the model to be used for inference.
model = None

@app.before_serving
async def startup():
    """This function is called once before the server starts to initialize the model."""
    global model
    model = Model(app)

@app.route('/embed', methods=['POST'])
async def embed():
    """This endpoint receives a CSV file, extracts text data from it, and uses the model to generate embeddings and topic information."""
    global model

    files = await request.files  # Get the uploaded files
    uploaded_file = files.get('file')  # Get the uploaded CSV file

    if not uploaded_file:
        return Response(json.dumps({"error": "No file uploaded"}), status=400, mimetype='application/json')

    # Read the CSV file into a DataFrame
    csv_data = pd.read_csv(io.BytesIO(uploaded_file.stream.read()))

    # Extract the text data
    text_data = csv_data['text'].tolist()

    # Create a ModelRequest object with the extracted text data
    req = ModelRequest(text=text_data)

    # Call the model's inference method and get the response
    response = await model.inference(req)

    if response is None:
        # If an error occurred during inference, return an error response
        return Response(json.dumps({"error": "Inference error"}), status=500, mimetype='application/json')

    # Convert the CSV string from the response into a DataFrame
    df = pd.read_csv(io.StringIO(response))

    # Save the DataFrame to a CSV file
    output_file_path = 'output.csv'
    df.to_csv(output_file_path, index=False)

    # Send the CSV file back as a download response
    return await send_file(output_file_path, mimetype='text/csv', as_attachment=True, attachment_filename='output.csv')
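The endpoint can also be exercised from Python rather than curl. The sketch below is not part of this commit: it assumes the container is running on localhost:8000, that `requests` is installed on the client side, and that `test.csv` is a hypothetical input file with the required 'text' column.

```
# Hypothetical client-side smoke test for the /embed endpoint (not part of
# this commit). Assumes the server is up on localhost:8000 and that test.csv
# has a 'text' column, as the API requires.
import requests

with open("test.csv", "rb") as f:
    # The multipart field must be named 'file' to match files.get('file') above.
    resp = requests.post("http://localhost:8000/embed", files={"file": f})

resp.raise_for_status()
with open("output.csv", "wb") as out:
    out.write(resp.content)  # CSV with per-document topic labels and names
```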

src/topic_modelling/BERTopic/model.py

Lines changed: 56 additions & 0 deletions

import pandas as pd
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from umap import UMAP
from sklearn.feature_extraction.text import CountVectorizer
import json
import nltk
from request import ModelRequest

nltk.download("punkt")

class Model:
    def __init__(self, context):
        self.context = context
        self.sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
        self.vectorizer_model = CountVectorizer(stop_words="english")
        self.umap_model = UMAP(n_neighbors=15, min_dist=0.0, metric="cosine", random_state=69)
        # self.hdbscan_model = HDBSCAN(min_cluster_size=15, metric="euclidean", prediction_data=True)
        self.topic_model = BERTopic(
            umap_model=self.umap_model,
            # hdbscan_model=self.hdbscan_model,
            vectorizer_model=self.vectorizer_model,
        )

    async def inference(self, request: ModelRequest):
        text = request.text
        try:
            # Encode the text using SentenceTransformer
            corpus_embeddings = self.sentence_model.encode(text)

            # Fit the topic model; fit_transform already returns each document's
            # topic label, so no separate transform() pass is needed
            topics, probabilities = self.topic_model.fit_transform(text, corpus_embeddings)

            # Get topic information
            df_classes = self.topic_model.get_topic_info()

            df_result = pd.DataFrame({
                "document_text": text,
                "predicted_class_label": topics,
                "probabilities": probabilities,
            })

            # Map cluster labels to human-readable topic names
            cluster_names_map = dict(zip(df_classes["Topic"], df_classes["Name"]))
            df_result["predicted_class_name"] = df_result["predicted_class_label"].map(cluster_names_map)

            csv_string = df_result.to_csv(index=False)

        except Exception as e:
            # Print the error and signal failure to the caller
            print(f"Error during inference: {e}")
            return None

        return csv_string
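The class can also be driven locally, without the Quart server. A minimal sketch, assuming it is run from this folder on a hypothetical `test.csv`; `context=None` stands in for the Quart app object, which `__init__` only stores:

```
# Hypothetical local smoke test for Model.inference (not part of this commit).
import asyncio
import pandas as pd
from model import Model
from request import ModelRequest

# Load a real corpus: BERTopic's defaults (UMAP n_neighbors=15, HDBSCAN)
# expect at least a few dozen documents to form clusters.
texts = pd.read_csv("test.csv")["text"].tolist()

model = Model(context=None)
csv_string = asyncio.run(model.inference(ModelRequest(text=texts)))
print(csv_string)  # document_text, predicted_class_label, probabilities, predicted_class_name
```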
src/topic_modelling/BERTopic/request.py

Lines changed: 9 additions & 0 deletions

import json

class ModelRequest():
    def __init__(self, text):
        self.text = text

    def to_json(self):
        return json.dumps(self, default=lambda o: o.__dict__,
                          sort_keys=True, indent=4)
src/topic_modelling/BERTopic/requirements.txt

Lines changed: 8 additions & 0 deletions

quart
aiohttp
pandas
bertopic
sentence_transformers
numpy
nltk
scikit-learn
# Direct dependencies of model.py and the Dockerfile entrypoint:
umap-learn
hypercorn
