-
Notifications
You must be signed in to change notification settings - Fork 895
/
Copy pathoutput_guardrails.py
80 lines (61 loc) · 2.3 KB
/
output_guardrails.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from __future__ import annotations
import asyncio
import json
from pydantic import BaseModel, Field
from agents import (
Agent,
GuardrailFunctionOutput,
OutputGuardrailTripwireTriggered,
RunContextWrapper,
Runner,
output_guardrail,
)
"""
This example shows how to use output guardrails.
Output guardrails are checks that run on the final output of an agent.
They can be used to do things like:
- Check if the output contains sensitive data
- Check if the output is a valid response to the user's message
In this example, we'll use a (contrived) example where we check if the agent's response contains
a phone number.
"""
# The agent's output type
class MessageOutput(BaseModel):
    # Structured output the agent is asked to produce (it is passed to the
    # Agent below as `output_type`).
    # NOTE: documented with `#` comments instead of a class docstring on
    # purpose -- pydantic copies a model's docstring into its generated JSON
    # schema, which would change what is sent alongside the field descriptions.

    # The model's own explanation of how it arrived at the response.
    reasoning: str = Field(description="Thoughts on how to respond to the user's message")
    # The text that is actually meant for the user.
    response: str = Field(description="The response to the user's message")
    # Nullable, but no default is supplied, so (under pydantic v2, which the
    # `model_dump()` call in `main` implies) the key is still required and
    # must be given explicitly, possibly as None.
    user_name: str | None = Field(description="The name of the user who sent the message, if known")
@output_guardrail
async def sensitive_data_check(
    context: RunContextWrapper, agent: Agent, output: MessageOutput
) -> GuardrailFunctionOutput:
    """Flag agent output that appears to contain a phone number.

    The check is intentionally contrived: it only looks for the substring
    "650" in the output's ``response`` and ``reasoning`` text.

    Args:
        context: Run context supplied by the caller; not used by this check.
        agent: The agent whose output is being checked; not used by this check.
        output: The structured agent output to inspect.

    Returns:
        A ``GuardrailFunctionOutput`` whose ``output_info`` records, per field,
        whether the substring was found, and whose ``tripwire_triggered`` is
        true when it was found in either field.
    """
    phone_number_in_response = "650" in output.response
    phone_number_in_reasoning = "650" in output.reasoning

    return GuardrailFunctionOutput(
        output_info={
            "phone_number_in_response": phone_number_in_response,
            "phone_number_in_reasoning": phone_number_in_reasoning,
        },
        # A match in either field is enough to trip the guardrail.
        tripwire_triggered=phone_number_in_response or phone_number_in_reasoning,
    )
# The agent exercised by this example: it must answer with a MessageOutput,
# and every final output is passed through the sensitive-data guardrail.
agent = Agent(
    name="Assistant",
    instructions="You are a helpful assistant.",
    output_type=MessageOutput,
    output_guardrails=[sensitive_data_check],
)
async def main() -> None:
    """Run the example agent on two prompts and report the guardrail outcome.

    The first prompt is run without any exception handling. The second prompt
    contains a phone number; the run is wrapped so that an
    ``OutputGuardrailTripwireTriggered`` exception is reported rather than
    propagated.
    """
    # This should be ok
    await Runner.run(agent, "What's the capital of California?")
    print("First message passed")

    # This should trip the guardrail
    try:
        result = await Runner.run(
            agent, "My phone number is 650-123-4567. Where do you think I live?"
        )
    except OutputGuardrailTripwireTriggered as e:
        print(f"Guardrail tripped. Info: {e.guardrail_result.output.output_info}")
    else:
        # Only reached when the run finished without tripping the guardrail.
        print(
            f"Guardrail didn't trip - this is unexpected. Output: {json.dumps(result.final_output.model_dump(), indent=2)}"
        )
if __name__ == "__main__":
    # Start the event loop only when executed as a script, not on import.
    asyncio.run(main())