-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from Linh-nk/gpt
get anomalies with chatgpt
- Loading branch information
Showing
10 changed files
with
611 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
Result post-processing module. | ||
This module contains functions that help convert model responses back to indices and timestamps. | ||
""" | ||
import numpy as np | ||
|
||
|
||
def str2sig(text, sep=',', decimal=0): | ||
"""Convert a text string to a signal. | ||
Convert a string containing digits into an array of numbers. | ||
Args: | ||
text (str): | ||
A string containing signal values. | ||
sep (str): | ||
String that was used to separate each element in text, Default to `","`. | ||
decimal (int): | ||
Number of decimal points to shift each element in text to. Default to `0`. | ||
Returns: | ||
numpy.ndarray: | ||
A 1-dimensional array containing parsed elements in `text`. | ||
""" | ||
# Remove all characters from text except the digits and sep and decimal point | ||
text = ''.join(i for i in text if (i.isdigit() or i == sep or i == '.')) | ||
values = np.fromstring(text, dtype=float, sep=sep) | ||
return values * 10**(-decimal) | ||
|
||
|
||
def str2idx(text, len_seq, sep=','): | ||
"""Convert a text string to indices. | ||
Convert a string containing digits into an array of indices. | ||
Args: | ||
text (str): | ||
A string containing indices values. | ||
len_seq (int): | ||
The length of processed sequence | ||
sep (str): | ||
String that was used to separate each element in text, Default to `","`. | ||
Returns: | ||
numpy.ndarray: | ||
A 1-dimensional array containing parsed elements in `text`. | ||
""" | ||
# Remove all characters from text except the digits and sep | ||
text = ''.join(i for i in text if (i.isdigit() or i == sep)) | ||
|
||
values = np.fromstring(text, dtype=int, sep=sep) | ||
|
||
# Remove indices that exceed the length of sequence | ||
values = values[values < len_seq] | ||
return values | ||
|
||
|
||
def get_anomaly_list_within_seq(res_list, alpha=0.5): | ||
"""Get the final list of anomalous indices of a sequence | ||
Choose anomalous index in the sequence based on multiple LLM responses | ||
Args: | ||
res_list (List[numpy.ndarray]): | ||
A list of 1-dimensional array containing anomous indices output by LLM | ||
alpha (float): | ||
Percentage of votes needed for an index to be deemed anomalous. Default: 0.5 | ||
Returns: | ||
numpy.ndarray: | ||
A 1-dimensional array containing final anomalous indices | ||
""" | ||
min_vote = np.ceil(alpha * len(res_list)) | ||
|
||
flattened_res = np.concatenate(res_list) | ||
|
||
unique_elements, counts = np.unique(flattened_res, return_counts=True) | ||
|
||
final_list = unique_elements[counts >= min_vote] | ||
|
||
return final_list | ||
|
||
|
||
def merge_anomaly_seq(anomalies, start_indices, window_size, step_size, beta=0.5): | ||
"""Get the final list of anomalous indices of a sequence when merging all rolling windows | ||
Args: | ||
anomalies (List[numpy.ndarray]): | ||
A list of 1-dimensional array containing anomous indices of each window | ||
start_indices (numpy.ndarray): | ||
A 1-dimensional array contaning the first index of each window | ||
window_size (int): | ||
Length of each window | ||
step_size (int): | ||
Indicating the number of steps the window moves forward each round. | ||
beta (float): | ||
Percentage of containing windows needed for index to be deemed anomalous. Default: 0.5 | ||
Return: | ||
numpy.ndarray: | ||
A 1-dimensional array containing final anomalous indices | ||
""" | ||
anomalies = [arr + first_idx for (arr, first_idx) in zip(anomalies, start_indices)] | ||
|
||
min_vote = np.ceil(beta * window_size / step_size) | ||
|
||
flattened_res = np.concatenate(anomalies) | ||
|
||
unique_elements, counts = np.unique(flattened_res, return_counts=True) | ||
|
||
final_list = unique_elements[counts >= min_vote] | ||
|
||
return np.sort(final_list) | ||
|
||
|
||
def idx2time(sequence, idx_list): | ||
"""Convert list of indices into list of timestamp | ||
Args: | ||
sequence (pandas.Dataframe): | ||
Signal with timestamps and values | ||
idx_list (numpy.ndarray): | ||
A 1-dimensional array of indices | ||
Returns: | ||
numpy.ndarray: | ||
A 1-dimensional array containing timestamps | ||
""" | ||
return sequence.iloc[idx_list].timestamp.to_numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
GPT model module. | ||
This module contains functions that are specifically used for GPT models | ||
""" | ||
import os | ||
|
||
from openai import OpenAI | ||
|
||
|
||
def load_system_prompt(file_path): | ||
with open(file_path) as f: | ||
system_prompt = f.read() | ||
return system_prompt | ||
|
||
|
||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
ZERO_SHOT_FILE = 'gpt_system_prompt_zero_shot.txt' | ||
ONE_SHOT_FILE = 'gpt_system_prompt_one_shot.txt' | ||
|
||
ZERO_SHOT_DIR = os.path.join(CURRENT_DIR, "..", "template", ZERO_SHOT_FILE) | ||
ONE_SHOT_DIR = os.path.join(CURRENT_DIR, "..", "template", ONE_SHOT_FILE) | ||
|
||
|
||
GPT_model = "gpt-4" # "gpt-4-0125-preview" # # #"gpt-3.5-turbo" # | ||
client = OpenAI() | ||
|
||
|
||
def get_gpt_model_response(message, gpt_model=GPT_model): | ||
completion = client.chat.completions.create( | ||
model=gpt_model, | ||
messages=message, | ||
) | ||
return completion.choices[0].message.content | ||
|
||
|
||
def create_message_zero_shot(seq_query, system_prompt_file=ZERO_SHOT_DIR): | ||
messages = [] | ||
|
||
messages.append({"role": "system", "content": load_system_prompt(system_prompt_file)}) | ||
|
||
# final prompt | ||
messages.append({"role": "user", "content": f"Sequence: {seq_query}"}) | ||
return messages | ||
|
||
|
||
def create_message_one_shot(seq_query, seq_ex, ano_idx_ex, system_prompt_file=ONE_SHOT_DIR): | ||
messages = [] | ||
|
||
messages.append({"role": "system", "content": load_system_prompt(system_prompt_file)}) | ||
|
||
# one shot | ||
messages.append({"role": "user", "content": f"Sequence: {seq_ex}"}) | ||
messages.append({"role": "assistant", "content": ano_idx_ex}) | ||
|
||
# final prompt | ||
messages.append({"role": "user", "content": f"Sequence: {seq_query}"}) | ||
return messages |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,40 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""Main module.""" | ||
""" | ||
Main module. | ||
This module contains functions that get LLM's anomaly detection results. | ||
""" | ||
from anomalies import get_anomaly_list_within_seq, str2idx | ||
from data import sig2str | ||
|
||
|
||
def get_anomalies(seq, msg_func, model_func, num_iters=1, alpha=0.5): | ||
"""Get LLM anomaly detection results. | ||
The function get the LLM's anomaly detection and converts them into an 1D array | ||
Args: | ||
seq (ndarray): | ||
The sequence to detect anomalies. | ||
msg_func (func): | ||
Function to create message prompt. | ||
model_func (func): | ||
Function to get LLM answer. | ||
num_iters (int): | ||
Number of times to run the same query. | ||
alpha (float): | ||
Percentage of total number of votes that an index needs to have to be | ||
considered anomalous. Default: 0.5 | ||
Returns: | ||
ndarray: | ||
1D array containing anomalous indices of the sequence. | ||
""" | ||
message = msg_func(sig2str(seq, space=True)) | ||
res_list = [] | ||
for i in range(num_iters): | ||
res = model_func(message) | ||
ano_ind = str2idx(res, len(seq)) | ||
res_list.append(ano_ind) | ||
return get_anomaly_list_within_seq(res_list, alpha=alpha) |
Oops, something went wrong.