Merge pull request #317 from Shubh-Goyal-07/restructure
Updates NER
Showing 7 changed files with 282 additions and 64 deletions.
@@ -1,21 +1,42 @@
## NER:

### Purpose:

Model to detect:

- crops
- pests
- seed type
- time
- phone numbers
- numbers with units
- dates

### Testing the model deployment:

To test just the Hugging Face deployment for grievance recognition, follow these steps:

- Git clone the repo
- Go to the current folder location, i.e. ``cd /src/ner/agri_ner_akai/local``
- Build the Docker image and test the API:

```
docker build -t testmodel .
docker run -p 8000:8000 testmodel
curl -X POST -H "Content-Type: application/json" -d '{"text": "What are tomatoes and potatoes that are being attacked by aphids?"}' http://localhost:8000/
```

### **Request**

```
curl -X POST -H "Content-Type: application/json" -d '{
  "text": "What are tomatoes and potatoes that are being attacked by aphids will be treated next monday?",
  "type": ["email", "CROP"]
}' http://localhost:8000/
```

```
curl -X POST -H "Content-Type: application/json" -d '{
  "text": "What are tomatoes and potatoes that are being attacked by aphids?"
}' http://localhost:8000/
```
@@ -0,0 +1,68 @@
```
from transformers import pipeline
from request import ModelRequest


class BertNERModel():
    def __new__(cls):
        # Singleton: the HF pipeline is loaded once and reused.
        if not hasattr(cls, 'instance'):
            cls.instance = super(BertNERModel, cls).__new__(cls)
            cls.nlp_ner = pipeline("ner", model="GautamR/akai_ner", tokenizer="GautamR/akai_ner")
        return cls.instance

    def inference(self, sentence):
        entities = self.nlp_ner(sentence)
        return self.aggregate_entities(sentence, entities)

    @staticmethod
    def aggregate_entities(sentence, entity_outputs):
        aggregated_entities = []
        current_entity = None

        for entity in entity_outputs:
            entity_type = entity["entity"].split("-")[-1]

            # Handle subwords
            if entity["word"].startswith("##"):
                # If we encounter an I-PEST or any other I- entity
                if "I-" in entity["entity"]:
                    if current_entity:  # Add previous entity
                        aggregated_entities.append(current_entity)

                    # Snap back to the enclosing whitespace-delimited word
                    word_start = sentence.rfind(" ", 0, entity["start"]) + 1
                    word_end = sentence.find(" ", entity["end"])
                    if word_end == -1:
                        word_end = len(sentence)

                    current_entity = {
                        "entity_group": entity_type,
                        "score": float(entity["score"]),
                        "word": sentence[word_start:word_end].replace('.', '').replace('?', ''),
                        "start": float(word_start),
                        "end": float(word_end)
                    }
                    aggregated_entities.append(current_entity)
                    current_entity = None

                else:
                    if current_entity:
                        # If it's a subword but not an I- entity, extend the open entity
                        current_entity["word"] += entity["word"][2:]
                        current_entity["end"] = entity["end"]
                        current_entity["score"] = float((current_entity["score"] + entity["score"]) / 2)  # averaging scores

            # Handle full words
            else:
                if current_entity:
                    aggregated_entities.append(current_entity)

                current_entity = {
                    "entity_group": entity_type,
                    "score": float(entity["score"]),
                    "word": entity["word"],
                    "start": float(entity["start"]),
                    "end": float(entity["end"])
                }

        if current_entity:
            aggregated_entities.append(current_entity)

        return aggregated_entities
```
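Because `aggregate_entities` is a `@staticmethod`, the subword merging can be exercised without downloading the model (importing `bert_ner` still requires `transformers` to be installed). A sketch with hand-written tokens that mimic the `ner` pipeline's output:

```
from bert_ner import BertNERModel

sentence = "aphids attack tomatoes"
# Hand-written pipeline-style tokens, not real model output:
# "##hids" is a continuation subword carrying an I- tag.
fake_entities = [
    {"entity": "B-PEST", "word": "ap", "score": 0.98, "start": 0, "end": 2},
    {"entity": "I-PEST", "word": "##hids", "score": 0.95, "start": 2, "end": 6},
    {"entity": "B-CROP", "word": "tomatoes", "score": 0.99, "start": 14, "end": 22},
]

for e in BertNERModel.aggregate_entities(sentence, fake_entities):
    print(e)
# Prints the flushed "ap" fragment, then a PEST span covering the recovered
# whole word "aphids", then a CROP span for "tomatoes".
```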
@@ -1,69 +1,51 @@
```
from request import ModelRequest
from regex_parse_ner import RegNERModel
from bert_ner import BertNERModel


class Model():
    def __init__(self, context):
        self.context = context
        print("Loading models...")
        self.regex_model = RegNERModel()
        print("Regex model loaded successfully")
        self.bert_model = BertNERModel()
        print("Bert model loaded successfully")

    def combine_entities(self, reg_entities, bert_entities):
        # Start from the regex results and fold the BERT spans in,
        # keyed by entity group.
        combined_entities = reg_entities

        for entity in bert_entities:
            if entity['entity_group'] not in combined_entities:
                combined_entities[entity['entity_group']] = []

            entity_info = {
                'name': entity['word'],
                'start': entity['start'],
                'end': entity['end'],
                'score': entity['score']
            }

            combined_entities[entity['entity_group']].append(entity_info)

        return combined_entities

    async def inference(self, request: ModelRequest):
        sentence = request.text
        types = request.type

        reg_entities = self.regex_model.inference(sentence)
        bert_entities = self.bert_model.inference(sentence)

        combined_entities = self.combine_entities(reg_entities, bert_entities)

        final_entities = {}

        if types is None:
            return combined_entities

        # Keep only the entity groups the caller asked for.
        for entity_group in combined_entities:
            if entity_group in types:
                final_entities[entity_group] = combined_entities[entity_group]

        return final_entities
```
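The merge-and-filter path can be traced by hand. A sketch of what `combine_entities` and the type filter in `inference` do, with made-up inputs in the shapes the two models return:

```
# reg_entities as RegNERModel.inference returns them (dict of groups);
# bert_entities as BertNERModel.inference returns them (flat list).
reg_entities = {
    "time": [{"name": "5 pm", "start": 30, "end": 34, "score": 1.0}],
}
bert_entities = [
    {"entity_group": "CROP", "word": "tomatoes", "start": 9, "end": 17, "score": 0.99},
]

# Same merge as Model.combine_entities:
combined = dict(reg_entities)
for e in bert_entities:
    combined.setdefault(e["entity_group"], []).append(
        {"name": e["word"], "start": e["start"], "end": e["end"], "score": e["score"]}
    )

# Same filter as Model.inference when request.type is ["CROP"]:
types = ["CROP"]
print({g: v for g, v in combined.items() if g in types})
# -> {'CROP': [{'name': 'tomatoes', 'start': 9, 'end': 17, 'score': 0.99}]}
```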
@@ -0,0 +1,143 @@
```
import re
import spacy
from datetime import datetime, timedelta


class RegNERModel():
    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")

        print("Model loaded successfully")

    def detect_email(self, sentence):
        # Raw string, and + instead of * so empty local parts don't match;
        # the stray "|" inside the original class matched literal pipes.
        email_regex_pattern = r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]+'
        emails_matches = []

        for match in re.finditer(email_regex_pattern, sentence):
            emails_matches.append({"name": match.group(), "start": match.start(), "end": match.end(), "score": 1.0})

        return emails_matches

    def detect_time(self, sentence):
        time_regex = r'\b(?:1[0-2]|0?[1-9])(?::[0-5][0-9])?(?:\s?[ap]m)?\b'
        times = []

        for match in re.finditer(time_regex, sentence, re.IGNORECASE):
            times.append({"name": match.group(), "start": match.start(), "end": match.end(), "score": 1.0})

        return times

    def detect_phone_numbers(self, sentence):
        phone_regex = r'(\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4})'

        phone_numbers = []
        for match in re.finditer(phone_regex, sentence):
            phone_numbers.append({"name": match.group(), "start": match.start(), "end": match.end(), "score": 1.0})

        return phone_numbers

    def detect_numbers_with_units(self, sentence, phone_numbers):
        number_unit_regex = r'(?<!\d)(\d+(?:\.\d+)?)(?:\s+)(\w+)(?!\d)'

        numbers_with_units = []

        for match in re.finditer(number_unit_regex, sentence):
            number, unit = match.groups()
            # Skip matches that overlap a detected phone number. (The original
            # test, `number not in phone_numbers`, compared a string against a
            # list of dicts and therefore never excluded anything.)
            overlaps_phone = any(
                match.start() < pn["end"] and pn["start"] < match.end()
                for pn in phone_numbers
            )
            if not overlaps_phone:
                numbers_with_units.append({"name": f"{number} {unit}", "start": match.start(), "end": match.end(), "score": 1.0})

        return numbers_with_units

    def detect_dates(self, sentence):
        # Current date
        today = datetime.now()

        # Regex patterns for relative date expressions (English and Hindi)
        patterns = [
            r"(next|agle)\s+(monday|tuesday|wednesday|thursday|friday|saturday|sunday|somvar|mangalwar|budhwar|guruwar|shukrawar|shaniwar|raviwar)",
            r"(kal)",
            r"(next|agle)\s+(week|month|year|hafte|mahine|saal)"
        ]

        # Map weekday names to Python's Monday=0 .. Sunday=6 indices.
        weekday_map = {
            'monday': 0, 'somvar': 0,
            'tuesday': 1, 'mangalwar': 1,
            'wednesday': 2, 'budhwar': 2,
            'thursday': 3, 'guruwar': 3,
            'friday': 4, 'shukrawar': 4,
            'saturday': 5, 'shaniwar': 5,
            'sunday': 6, 'raviwar': 6,
        }

        detected_dates = []

        # Search each pattern in the lowercased text
        for pattern in patterns:
            for matchdates in re.finditer(pattern, sentence.lower()):
                match = matchdates.groups()
                if match[0] in ['next', 'agle']:
                    if match[1] in weekday_map:
                        # Days until the next occurrence of that weekday; 0
                        # would mean today, so roll a full week forward. (The
                        # original per-weekday offsets drifted with the current
                        # weekday and could resolve to past dates.)
                        days_ahead = (weekday_map[match[1]] - today.weekday()) % 7
                        if days_ahead == 0:
                            days_ahead = 7
                        next_date = today + timedelta(days=days_ahead)
                    elif match[1] in ['week', 'hafte']:
                        # Resolve "next week" to the Sunday that ends it
                        next_date = today + timedelta(days=(7 - today.weekday()) + 6)
                    elif match[1] in ['month', 'mahine']:
                        # First day of next month (roll the year in December;
                        # the original month+1 raised ValueError in December)
                        if today.month == 12:
                            next_date = today.replace(day=1, month=1, year=today.year + 1)
                        else:
                            next_date = today.replace(day=1, month=today.month + 1)
                    elif match[1] in ['year', 'saal']:
                        # First day of next year
                        next_date = today.replace(day=1, month=1, year=today.year + 1)
                    else:
                        continue
                    detected_dates.append({"name": next_date.strftime("%d-%m-%Y"), "start": matchdates.start(), "end": matchdates.end(), "score": 1.0})
                elif match[0] == 'kal':
                    # "kal" = tomorrow
                    next_date = today + timedelta(1)
                    detected_dates.append({"name": next_date.strftime("%d-%m-%Y"), "start": matchdates.start(), "end": matchdates.end(), "score": 1.0})

        return detected_dates

    def inference(self, sentence):
        detected_emails = self.detect_email(sentence)
        detected_time = self.detect_time(sentence)
        detected_phone_numbers = self.detect_phone_numbers(sentence)
        detected_number_units = self.detect_numbers_with_units(sentence, detected_phone_numbers)
        detected_dates = self.detect_dates(sentence)

        aggregated_entities = {}

        # Only include groups that actually matched
        if detected_emails:
            aggregated_entities["email"] = detected_emails
        if detected_time:
            aggregated_entities["time"] = detected_time
        if detected_phone_numbers:
            aggregated_entities["phone_number"] = detected_phone_numbers
        if detected_number_units:
            aggregated_entities["number_with_unit"] = detected_number_units
        if detected_dates:
            aggregated_entities["date"] = detected_dates

        return aggregated_entities
```
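A quick smoke test for the regex extractors, as a sketch: it assumes the spaCy model is installed (`python -m spacy download en_core_web_sm`), the input text is made up, and the resolved date depends on the day it is run:

```
from regex_parse_ner import RegNERModel

model = RegNERModel()
print(model.inference(
    "Call 555-123-4567 or mail rahul@example.com; spray 5 ml next monday"
))
# Expected groups, inferred from the code above: phone_number, email,
# number_with_unit ("5 ml"), date (next Monday, dd-mm-yyyy), and time,
# since the time regex also fires on the bare digit "5".
```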