Skip to content

Commit 3333d9b

Browse files
sundy1994xehu
andauthored
Yuxuan/liwc test (#341)
* add LIWC tests and loading checks * Small updates to documentation and adding .__version__ parameter (#332) * update examples hierarchy * Closes #318. * provide __version__ variable without setup.py * small fix --------- Co-authored-by: sundy1994 <[email protected]> * update docs in liwc.rst and build again * Revert "update docs in liwc.rst and build again" This reverts commit 009d33e. --------- Co-authored-by: Xinlan Emily Hu <[email protected]>
1 parent 4211f0d commit 3333d9b

File tree

5 files changed

+1773
-801
lines changed

5 files changed

+1773
-801
lines changed

.gitignore

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,4 @@ node_modules/
5858
# testing
5959
/output
6060
/vector_data
61-
test.py
62-
63-
64-
61+
test.py
Binary file not shown.

src/team_comm_tools/utils/check_embeddings.py

Lines changed: 99 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
tokenizer = AutoTokenizer.from_pretrained(MODEL)
2323
model_bert = AutoModelForSequenceClassification.from_pretrained(MODEL)
2424
os.environ["TOKENIZERS_PARALLELISM"] = "false"
25+
EMOJIS_TO_PRESERVE = {
26+
"(:", "(;", "):", "/:", ":(", ":)", ":/", ";)"
27+
}
2528

2629
# Check if embeddings exist
2730
def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, need_sentence: bool,
@@ -86,25 +89,11 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, ne
8689
# Read in the lexicons (helper function for generating the pickle file)
8790
def read_in_lexicons(directory, lexicons_dict):
    """
    Read every lexicon file in ``directory`` and store a regex pattern for
    each one in ``lexicons_dict``.

    Each file is opened with the "mac_roman" encoding (the encoding the
    lexicon files ship in) and its lines are converted into a single
    '|'-joined regex alternation by :func:`sort_words`. Entries are keyed by
    the filename with its ``.txt`` suffix removed.

    :param directory: Path to the folder of lexicon files (joined to each
        filename with the ``/`` operator, so a ``pathlib.Path`` is expected).
    :type directory: pathlib.Path
    :param lexicons_dict: Mapping updated in place with
        ``{lexicon_name: regex_pattern}`` entries.
    :type lexicons_dict: dict

    :return: None; ``lexicons_dict`` is mutated in place.
    """
    for filename in os.listdir(directory):
        # Skip hidden files such as .DS_Store.
        if filename.startswith("."):
            continue
        with open(directory / filename, encoding="mac_roman") as lexicons:
            # BUGFIX: the previous pattern '.txt' was unanchored and treated
            # '.' as a regex wildcard, so any "<char>txt" substring anywhere
            # in the filename would be removed. Escape the dot and anchor to
            # the end of the name.
            clean_name = re.sub(r'\.txt$', '', filename)
            lexicons_dict[clean_name] = sort_words(lexicons)
10897

10998
def generate_lexicon_pkl():
11099
"""
@@ -172,38 +161,80 @@ def fix_abbreviations(dicTerm: str) -> str:
172161
else:
173162
return dicTerm
174163

175-
def is_valid_term(dicTerm: str) -> bool:
    """
    Check if a dictionary term is valid.

    This function returns True if the term matches the regex pattern and
    False otherwise. The regex pattern matches:

    - Alphanumeric characters (a-zA-Z0-9)
    - Valid symbols: -, ', *, /
    - The * symbol can only appear once at the end of a word
    - The 8 emojis in ``EMOJIS_TO_PRESERVE`` are valid only when they appear alone
    - The / symbol can only appear once after alphanumeric characters
    - Spaces are allowed between valid words

    :param dicTerm: The dictionary term
    :type dicTerm: str

    :return: True/False
    :rtype: bool
    """
    # Emojis contain regex metacharacters ("(", ")", "/"), so escape each one
    # before building the alternation.
    emoji_pattern = '|'.join(re.escape(emoji) for emoji in EMOJIS_TO_PRESERVE)
    alphanumeric_pattern = (
        fr"^([a-zA-Z0-9\-']+(\*|\/[a-zA-Z0-9\*]*)?|({emoji_pattern})\*?)( [a-zA-Z0-9\-']+(\*|\/[a-zA-Z0-9\*]*)?)*$"
    )

    return bool(re.match(alphanumeric_pattern, dicTerm))
206191

192+
def sort_words(lexicons: list) -> str:
    """
    Convert an iterable of lexicon terms into a single regex alternation.

    Each term is stripped, has its parentheses escaped, is wrapped in word
    boundaries (except for terms containing the emojis in
    ``EMOJIS_TO_PRESERVE``, which get no boundaries), and the terms are then
    joined with '|'. Hyphenated terms are placed before non-hyphenated ones,
    and each group is ordered by original term length in descending order so
    the regex engine tries longer, more specific alternatives first.

    :param lexicons: Iterable of lexicon terms (e.g. the lines of a lexicon file)
    :type lexicons: list

    :return: A single '|'-joined regex pattern covering all terms
    :rtype: str
    """
    hyphenated_words = []
    non_hyphenated_words = []
    for lexicon in lexicons:
        lexicon = lexicon.strip()
        lexicon = lexicon.replace("\n", "")
        if lexicon == '':
            continue
        # Sort by the raw term's length, before regex escaping inflates it.
        length = len(lexicon)
        # no word boundaries for emojis
        if any(emoji in lexicon for emoji in EMOJIS_TO_PRESERVE):
            lexicon = lexicon.replace('(', r'\(').replace(')', r'\)')
        else:
            lexicon = lexicon.replace('(', r'\(').replace(')', r'\)')
            word_boundaries = r"\b", r"\b"
            if lexicon[-1] == "*":
                # Collapse repeated asterisks (e.g. '**' -> '*') before
                # expanding the wildcard into \S*.
                pattern = re.compile(r'\*+')
                lexicon = pattern.sub('*', lexicon)
                if not lexicon[-2].isalnum():
                    # \b does not match next to non-word characters; use
                    # lookarounds instead.
                    word_boundaries = r"(?<!\w)", r"(?!\w)"
                lexicon = lexicon.replace("*", r"\S*")
            elif not lexicon[-1].isalnum():
                word_boundaries = r"(?<!\w)", r"(?!\w)"
            # str.join places the pattern between the two boundary tokens:
            # boundaries[0] + lexicon + boundaries[1].
            lexicon = lexicon.join(word_boundaries)
        if '-' in lexicon:
            hyphenated_words.append((lexicon, length))
        else:
            non_hyphenated_words.append((lexicon, length))
    hyphenated_words.sort(key=lambda x: x[1], reverse=True)
    non_hyphenated_words.sort(key=lambda x: x[1], reverse=True)
    sorted_words = hyphenated_words + non_hyphenated_words
    sorted_words = [lexicon for lexicon, _ in sorted_words]
    return '|'.join(sorted_words)
237+
207238
def load_liwc_dict(dicText: str) -> dict:
    """
    Loads up a dictionary that is in the LIWC 2007/2015 format.

    This function reads the content of a LIWC dictionary file in the official format,
    and converts it to a dictionary with lexicon: regular expression format.
    We assume the dicText has two parts: the header, which maps numbers to "category names,"
    and the body, which maps words in the lexicon to different category numbers, separated by '%'.
    Below is an example:
    '''
    %
    1   function
    2   pronoun
    3   ppron
    %
    again   1 2
    against 1 2 3
    '''
    Note that the elements in each line are separated by tab characters.

    :param dicText: The content of a .dic file
    :type dicText: str

    :return: Mapping of category name to a '|'-joined regex pattern of its lexicon terms
    :rtype: dict

    :raises ValueError: If the file does not contain exactly two '%' section
        separators, or a header line is not "<number>\\t<name>".
    """
    dicSplit = dicText.split('%', 2)
    # check 2 '%' symbols
    if len(dicSplit) != 3:
        raise ValueError("Invalid dictionary file.")
    dicHeader, dicBody = dicSplit[1], dicSplit[2]
    # read headers
    catNameNumberMap = {}
    for line in dicHeader.splitlines():
        if line.strip() == '':
            continue
        lineSplit = line.strip().split('\t')
        # check header format: 1 function
        if len(lineSplit) != 2 or not lineSplit[0].isdigit():
            raise ValueError("Invalid dictionary file.")
        catNameNumberMap[lineSplit[0]] = lineSplit[1]
    # read body
    dicCategories = {}
    for line in dicBody.splitlines():
        lineSplit = line.strip().split('\t')
        # check body format: again 1 2
        # A term with no category numbers is skipped here; a fully empty line
        # falls through and is dropped by the lexicon == '' check below.
        if lineSplit != [''] and len(lineSplit) < 2:
            continue
        lexicon, catNums = lineSplit[0], lineSplit[1:]
        # Lowercase, collapse internal whitespace, and expand abbreviations
        # before validation.
        lexicon = fix_abbreviations(dicTerm=' '.join(lineSplit[0].lower().strip().split()))
        lexicon = lexicon.strip()
        if lexicon == '':
            continue
        if not is_valid_term(lexicon):
            warnings.warn(f"WARNING: invalid lexicon: {lexicon}, skipped")
            continue

        for catNum in catNums:
            # NOTE(review): raises KeyError if a category number in the body
            # is missing from the header — confirm this is intended.
            cat = catNameNumberMap[catNum]
            if cat not in dicCategories:
                dicCategories[cat] = [lexicon]
            else:
                dicCategories[cat].append(lexicon)
    # sort the words in the dictionary
    for cat, lexicons in dicCategories.items():
        dicCategories[cat] = sort_words(lexicons)
    return dicCategories
261306

262307
def generate_certainty_pkl():

0 commit comments

Comments
 (0)