-
Notifications
You must be signed in to change notification settings - Fork 895
/
Copy pathinput_guardrails.py
105 lines (82 loc) · 3.06 KB
/
input_guardrails.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations
import asyncio
from pydantic import BaseModel
from agents import (
Agent,
GuardrailFunctionOutput,
InputGuardrailTripwireTriggered,
RunContextWrapper,
Runner,
TResponseInputItem,
input_guardrail,
)
"""
This example shows how to use guardrails.
Guardrails are checks that run in parallel to the agent's execution.
They can be used to do things like:
- Check if input messages are off-topic
- Check that output messages don't violate any policies
- Take over control of the agent's execution if an unexpected input is detected
In this example, we'll set up an input guardrail that trips if the user is asking to do math homework.
If the guardrail trips, we'll respond with a refusal message.
"""
### 1. An agent-based guardrail that is triggered if the user is asking to do math homework
# Structured verdict that the guardrail agent is asked to produce (it is passed
# as that agent's ``output_type``).
class MathHomeworkOutput(BaseModel):
    # Free-text explanation accompanying the verdict.
    reasoning: str
    # True when the input was judged to be a math-homework request; the
    # guardrail function uses this value as its tripwire flag.
    is_math_homework: bool
# Classifier agent invoked by the input guardrail; its final output is read
# back as a ``MathHomeworkOutput``.
guardrail_agent = Agent(
    name="Guardrail check",
    instructions="Check if the user is asking you to do their math homework.",
    output_type=MathHomeworkOutput,
)
@input_guardrail
async def math_guardrail(
    context: RunContextWrapper[None],
    agent: Agent,
    input: str | list[TResponseInputItem],
) -> GuardrailFunctionOutput:
    """Input guardrail that delegates the homework check to ``guardrail_agent``.

    The incoming input is forwarded unchanged to the guardrail agent, reusing
    the caller's run context. The agent's structured verdict is returned as
    ``output_info``, and the tripwire fires when that verdict marks the input
    as math homework.
    """
    check_run = await Runner.run(guardrail_agent, input, context=context.context)
    verdict = check_run.final_output_as(MathHomeworkOutput)

    return GuardrailFunctionOutput(
        output_info=verdict,
        tripwire_triggered=verdict.is_math_homework,
    )
### 2. The run loop
async def main() -> None:
    """Run an interactive chat loop guarded by ``math_guardrail``.

    Each line read from stdin is appended to the conversation as a user
    message and the whole conversation is sent to the customer support agent.
    If the run succeeds, the reply is printed and the run's input list becomes
    the conversation for the next turn. If the input guardrail trips, a fixed
    refusal is printed and recorded as the assistant's turn instead.

    The loop ends cleanly when stdin reaches end-of-file.
    """
    agent = Agent(
        name="Customer support agent",
        instructions="You are a customer support agent. You help customers with their questions.",
        input_guardrails=[math_guardrail],
    )

    input_data: list[TResponseInputItem] = []

    while True:
        try:
            user_input = input("Enter a message: ")
        except EOFError:
            # stdin is exhausted (Ctrl-D or end of piped input): finish the
            # prompt line and stop instead of crashing with a traceback.
            print()
            break

        input_data.append(
            {
                "role": "user",
                "content": user_input,
            }
        )

        try:
            # Only this call can trip the guardrail, so keep the try body minimal.
            result = await Runner.run(agent, input_data)
        except InputGuardrailTripwireTriggered:
            # If the guardrail triggered, we instead add a refusal message to the input
            message = "Sorry, I can't help you with your math homework."
            print(message)
            input_data.append(
                {
                    "role": "assistant",
                    "content": message,
                }
            )
        else:
            print(result.final_output)
            # If the guardrail didn't trigger, we use the result as the input for the next run
            input_data = result.to_input_list()
# Sample run:
# Enter a message: What's the capital of California?
# The capital of California is Sacramento.
# Enter a message: Can you help me solve for x: 2x + 5 = 11
# Sorry, I can't help you with your math homework.
# Start the interactive loop only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())