-
Notifications
You must be signed in to change notification settings - Fork 126
Expand file tree
/
Copy pathplugin_set_composition.py
More file actions
155 lines (129 loc) · 4.66 KB
/
Copy pathplugin_set_composition.py
File metadata and controls
155 lines (129 loc) · 4.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# pytest: ollama, e2e
#
# PluginSet composition — group hooks by concern and register them together.
#
# This example shows how to:
# 1. Define hooks across different concerns (security, observability)
# 2. Group them into PluginSets
# 3. Register observability globally and security per-session
#
# Run:
# uv run python docs/examples/plugins/plugin_set_composition.py
import logging
import sys
from mellea import start_session
from mellea.plugins import (
HookType,
PluginMode,
PluginSet,
PluginViolationError,
block,
hook,
register,
)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("mellea").setLevel(logging.ERROR)
log = logging.getLogger("plugin_set")
# --- Security hooks ---
TOKEN_BUDGET = 4000
@hook(HookType.GENERATION_PRE_CALL, mode=PluginMode.SEQUENTIAL, priority=10)
async def enforce_token_budget(payload, ctx):
"""Enforce a conservative token budget."""
# Rough token estimate: ~4 chars per token
prompt_chars = sum(
len(str(c.format_for_llm()))
for c in (payload.context.view_for_generation() or [])
) + len(
str(
payload.action.format_for_llm()
if hasattr(payload.action, "format_for_llm")
else payload.action
)
)
estimated = prompt_chars // 4 or 0
log.info("[security/token-budget] estimated=%d budget=%d", estimated, TOKEN_BUDGET)
if estimated > TOKEN_BUDGET:
return block(
f"Estimated {estimated} tokens exceeds budget of {TOKEN_BUDGET}",
code="TOKEN_BUDGET_001",
)
@hook(HookType.COMPONENT_PRE_EXECUTE, mode=PluginMode.SEQUENTIAL, priority=20)
async def enforce_description_length(payload, ctx):
"""Reject component actions whose text representation is too long."""
max_len = 2000
action_text = str(payload.action._description) if payload.action else ""
if len(action_text) > max_len:
log.info("[security/desc-length] BLOCKED: action is %d chars", len(action_text))
return block(
f"Action text exceeds {max_len} characters", code="DESC_LENGTH_001"
)
log.info("[security/desc-length] action length OK (%d chars)", len(action_text))
# --- Observability hooks ---
@hook(HookType.SESSION_POST_INIT, mode=PluginMode.AUDIT)
async def trace_session_start(payload, ctx):
"""Trace session initialization."""
log.info(
"[observability/trace] session started (session_id=%s)", payload.session_id
)
@hook(HookType.COMPONENT_POST_SUCCESS, mode=PluginMode.AUDIT)
async def trace_component_success(payload, ctx):
"""Trace successful component executions."""
log.info(
"[observability/trace] %s completed in %dms",
payload.component_type,
payload.latency_ms,
)
@hook(HookType.SESSION_CLEANUP, mode=PluginMode.AUDIT)
async def trace_session_end(payload, ctx):
"""Trace session cleanup."""
log.info(
"[observability/trace] session cleanup (interactions=%d)",
payload.interaction_count,
)
# --- Compose into PluginSets ---
security = PluginSet("security", [enforce_token_budget, enforce_description_length])
observability = PluginSet(
"observability", [trace_session_start, trace_component_success, trace_session_end]
)
if __name__ == "__main__":
log.info("--- PluginSet composition example ---")
log.info("")
# Register observability globally — fires for all sessions
register(observability)
log.info("Registered observability plugins globally")
log.info("")
# Session with security plugins (session-scoped) + global observability
log.info("=== Session with security + observability ===")
with start_session(plugins=[security]) as m:
try:
result = m.instruct("Name three prime numbers.")
log.info("Result: %s", result)
except PluginViolationError as e:
log.error(
"Execution blocked on %s: [%s] %s (plugin=%s)",
e.hook_type,
e.code,
e.reason,
e.plugin_name,
)
sys.exit(1)
log.info("")
log.info("=== Session with observability only ===")
with start_session() as m:
try:
result = m.instruct("What is 2 + 2?")
log.info("Result: %s", result)
except PluginViolationError as e:
log.warning(
"Execution blocked on %s: [%s] %s (plugin=%s)",
e.hook_type,
e.code,
e.reason,
e.plugin_name,
)
sys.exit(1)