1+ from enum import Enum
12import secrets
2- from typing import Any , Dict , Optional , Tuple , Union
3+ from typing import List , Optional , Dict , Any , Union , Tuple
34
45import requests
56from pydantic import BaseModel
67
7-
8- class DetectApiRequest (BaseModel ):
9- userInput : str
10- userInputBase64 : Optional [str ] = None
11- runHeuristicCheck : bool
12- runVectorCheck : bool
13- runLanguageModelCheck : bool
14- maxHeuristicScore : float
15- maxModelScore : float
16- maxVectorScore : float
17-
18-
19- class DetectApiSuccessResponse (BaseModel ):
20- heuristicScore : float
21- modelScore : float
22- vectorScore : Dict [str , float ]
23- runHeuristicCheck : bool
24- runVectorCheck : bool
25- runLanguageModelCheck : bool
26- maxHeuristicScore : float
27- maxModelScore : float
28- maxVectorScore : float
29- injectionDetected : bool
30-
31-
32- class ApiFailureResponse (BaseModel ):
33- error : str
34- message : str
def to_camel(string: str) -> str:
    """Convert a snake_case name to its camelCase equivalent.

    The first segment is kept as-is; every following segment is
    capitalized and the pieces are joined without separators.
    """
    head, *rest = string.split("_")
    return head + "".join(segment.capitalize() for segment in rest)
11+
class RebuffBaseModel(BaseModel):
    """Shared base for all Rebuff API models.

    Fields serialize under camelCase aliases (to match the API's wire
    format) while snake_case field names remain accepted on input.
    """

    class Config:
        alias_generator = to_camel
        populate_by_name = True
16+
17+
class TacticName(Enum):
    """Identifiers for the detection tactics the Rebuff API can run."""

    # A series of heuristics are used to determine whether the input is
    # prompt injection.
    HEURISTIC = "heuristic"

    # A language model is asked if the input appears to be prompt injection.
    LANGUAGE_MODEL = "language_model"

    # A vector database of known prompt injection attacks is queried for
    # similarity.
    VECTOR_DB = "vector_db"
33+
class TacticOverride(RebuffBaseModel):
    """Override settings for a specific detection tactic."""

    # The name of the tactic to override.
    name: TacticName

    # Detection threshold for this tactic: a score above it counts as
    # detected. When None, the tactic's server-side default is used.
    threshold: Optional[float] = None

    # Whether to run this tactic at all. Defaults to True when not given.
    run: Optional[bool] = True
54+
class DetectRequest(RebuffBaseModel):
    """Request payload for a prompt-injection detection call."""

    # The user input to check for prompt injection.
    user_input: str

    # Base64-encoded user input; when specified, the plain user input is
    # ignored by the API.
    user_input_base64: Optional[str] = None

    # Per-tactic behavior overrides. Tactics not listed here run with
    # their default thresholds.
    tactic_overrides: Optional[List[TacticOverride]] = None
74+
class TacticResult(RebuffBaseModel):
    """Outcome of executing a single detection tactic."""

    # Name of the tactic that produced this result.
    name: str

    # Score in [0, 1]; the closer to 1, the more likely the input is a
    # prompt injection attempt.
    score: float

    # Whether this tactic judged the input to be a prompt injection attempt.
    detected: bool

    # Threshold the score was compared against; scores above it count as
    # detected.
    threshold: float

    # Extra tactic-specific data. Known keys:
    #   "vector_db":
    #     - "countOverMaxVectorScore" (int): number of vectors whose
    #       similarity score exceeded the threshold.
    additional_fields: Dict[str, Any]
107+
class DetectResponse(RebuffBaseModel):
    """Response returned by a prompt-injection detection request."""

    # Whether any tactic flagged the input as prompt injection.
    injection_detected: bool

    # One result entry per tactic that was executed.
    tactic_results: List[TacticResult]
122+
class ApiFailureResponse(Exception):
    """Raised when the Rebuff API returns an error payload.

    Attributes:
        error: Short error identifier returned by the API.
        message: Human-readable explanation of the failure.
    """

    def __init__(self, error: str, message: str):
        # Normalized message format (no stray spaces around the values).
        super().__init__(f"Error: {error}, Message: {message}")
        self.error = error
        self.message = message
35128
36129
37130class Rebuff :
38- def __init__ (self , api_token : str , api_url : str = "https://playground .rebuff.ai" ):
131+ def __init__ (self , api_token : str , api_url : str = "https://www .rebuff.ai/playground " ):
39132 self .api_token = api_token
40133 self .api_url = api_url
41134 self ._headers = {
@@ -46,63 +139,47 @@ def __init__(self, api_token: str, api_url: str = "https://playground.rebuff.ai"
46139 def detect_injection (
47140 self ,
48141 user_input : str ,
49- max_heuristic_score : float = 0.75 ,
50- max_vector_score : float = 0.90 ,
51- max_model_score : float = 0.9 ,
52- check_heuristic : bool = True ,
53- check_vector : bool = True ,
54- check_llm : bool = True ,
55- ) -> Union [DetectApiSuccessResponse , ApiFailureResponse ]:
142+ tactic_overrides : Optional [List [TacticOverride ]] = None ,
143+ ) -> DetectResponse :
56144 """
57145 Detects if the given user input contains an injection attempt.
58146
59147 Args:
60148 user_input (str): The user input to be checked for injection.
61- max_heuristic_score (float, optional): The maximum heuristic score allowed. Defaults to 0.75.
62- max_vector_score (float, optional): The maximum vector score allowed. Defaults to 0.90.
63- max_model_score (float, optional): The maximum model (LLM) score allowed. Defaults to 0.9.
64- check_heuristic (bool, optional): Whether to run the heuristic check. Defaults to True.
65- check_vector (bool, optional): Whether to run the vector check. Defaults to True.
66- check_llm (bool, optional): Whether to run the language model check. Defaults to True.
149+ tactic_overrides (Optional[List[TacticOverride]], optional): A list of tactics to override.
150+ If a tactic is not specified in this list, the default threshold for that tactic will be used.
67151
68152 Returns:
69- Tuple[Union[DetectApiSuccessResponse, ApiFailureResponse], bool]: A tuple containing the detection
70- metrics and a boolean indicating if an injection was detected.
153+ DetectResponse: An object containing the detection metrics and a boolean indicating if an injection was
154+ detected.
155+
156+ Example:
157+ >>> from rebuff import Rebuff, TacticOverride, TacticName
158+ >>> rb = Rebuff(api_token='your_api_token')
159+ >>> user_input = "Your user input here"
160+ >>> tactic_overrides = [
161+ ... TacticOverride(name=TacticName.HEURISTIC, threshold=0.6),
162+ ... TacticOverride(name=TacticName.LANGUAGE_MODEL, run=False),
163+ ... ]
164+ >>> response = rb.detect_injection(user_input, tactic_overrides)
71165 """
72- request_data = DetectApiRequest (
73- userInput = user_input ,
74- userInputBase64 = encode_string (user_input ),
75- runHeuristicCheck = check_heuristic ,
76- runVectorCheck = check_vector ,
77- runLanguageModelCheck = check_llm ,
78- maxVectorScore = max_vector_score ,
79- maxModelScore = max_model_score ,
80- maxHeuristicScore = max_heuristic_score ,
166+ request_data = DetectRequest (
167+ user_input = user_input ,
168+ user_input_base64 = encode_string (user_input ),
169+ tactic_overrides = tactic_overrides ,
81170 )
82171
83172 response = requests .post (
84173 f"{ self .api_url } /api/detect" ,
85- json = request_data .dict ( ),
174+ json = request_data .model_dump ( mode = "json" , by_alias = True , exclude_none = True ),
86175 headers = self ._headers ,
87176 )
88177
89- response .raise_for_status ()
90-
91178 response_json = response .json ()
92- success_response = DetectApiSuccessResponse .parse_obj (response_json )
93-
94- if (
95- success_response .heuristicScore > max_heuristic_score
96- or success_response .modelScore > max_model_score
97- or success_response .vectorScore ["topScore" ] > max_vector_score
98- ):
99- # Injection detected
100- success_response .injectionDetected = True
101- return success_response
102- else :
103- # No injection detected
104- success_response .injectionDetected = False
105- return success_response
179+ if "error" in response_json :
180+ raise ApiFailureResponse (response_json ["error" ], response_json .get ("message" , "No message provided" ))
181+ response .raise_for_status ()
182+ return DetectResponse .model_validate (response_json )
106183
107184 @staticmethod
108185 def generate_canary_word (length : int = 8 ) -> str :
0 commit comments