
Agent Middleware

Middleware system for processing requests and responses in agent workflows, providing safety checks, cost tracking, rate limiting, and other cross-cutting concerns.

Core Classes

GuardrailsMiddleware

Description: Middleware for input/output filtering and safety checks

Parameters:

  • policy_engine (PolicyEngine, optional): Policy engine for safety checks
  • enable_input_filtering (bool): Enable input filtering
  • enable_output_filtering (bool): Enable output filtering
  • block_on_violation (bool): Block requests on policy violations

Returns: GuardrailsMiddleware instance

Example:

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import GuardrailsMiddleware
from recoagent.agents.policies import PolicyEngine, SafetyPolicy

# Create a policy engine that flags text matching these patterns
policy_engine = PolicyEngine([
    SafetyPolicy(blocked_patterns=[r"(?i)(harmful|illegal)"])
])

# Create guardrails middleware
guardrails = GuardrailsMiddleware(
    policy_engine=policy_engine,
    enable_input_filtering=True,
    enable_output_filtering=True,
    block_on_violation=True
)

# Use with agent (config, llm, and tools are assumed to be defined elsewhere)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[guardrails]
)
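Conceptually, guardrails apply the blocked patterns twice: to the query before the agent runs and to the answer before it is returned. The standalone sketch below approximates that flow without recoagent installed. `find_violations`, `guarded_call`, and `PolicyViolation` are hypothetical names, and the sketch assumes patterns are matched with `re.search` and that `block_on_violation=False` simply lets the text through; the library may behave differently (for example, by logging or redacting instead).

```python
import re
from typing import Callable, List


class PolicyViolation(Exception):
    """Hypothetical error raised when blocked content is detected."""


def find_violations(text: str, blocked_patterns: List[str]) -> List[str]:
    # Return every pattern that matches anywhere in the text.
    return [p for p in blocked_patterns if re.search(p, text)]


def guarded_call(
    query: str,
    run_agent: Callable[[str], str],
    blocked_patterns: List[str],
    enable_input_filtering: bool = True,
    enable_output_filtering: bool = True,
    block_on_violation: bool = True,
) -> str:
    """Hypothetical wrapper: filter input, run the agent, filter output."""
    # Input filtering: inspect the query before the agent sees it.
    if enable_input_filtering:
        hits = find_violations(query, blocked_patterns)
        if hits and block_on_violation:
            raise PolicyViolation(f"input blocked by {hits}")
    answer = run_agent(query)
    # Output filtering: inspect the answer before returning it.
    if enable_output_filtering:
        hits = find_violations(answer, blocked_patterns)
        if hits and block_on_violation:
            raise PolicyViolation(f"output blocked by {hits}")
    # With block_on_violation=False this sketch passes text through unchanged.
    return answer


patterns = [r"(?i)(harmful|illegal)"]
print(guarded_call("What is machine learning?", lambda q: "ML is ...", patterns))
```

The `(?i)` prefix makes each pattern case-insensitive, so "ILLEGAL" and "illegal" are treated alike.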

CostTrackingMiddleware

Description: Middleware for tracking and limiting costs

Parameters:

  • max_cost_per_request (float): Maximum cost per request
  • max_cost_per_hour (float): Maximum cost per hour
  • cost_tracking_enabled (bool): Enable cost tracking
  • alert_threshold (float): Cost alert threshold

Returns: CostTrackingMiddleware instance

Example:

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import CostTrackingMiddleware

# Create cost tracking middleware
cost_tracker = CostTrackingMiddleware(
    max_cost_per_request=0.10,
    max_cost_per_hour=5.00,
    cost_tracking_enabled=True,
    alert_threshold=0.80
)

# Use with agent (config, llm, and tools are assumed to be defined elsewhere)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[cost_tracker]
)
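One way to picture these limits is a rolling one-hour window of recorded costs. This is a hedged sketch, not the middleware's implementation: `CostBudget` and its methods are hypothetical, and it assumes `alert_threshold` is a fraction of `max_cost_per_hour` (so 0.80 with a $5.00 hourly limit would alert at $4.00). Verify the library's interpretation before relying on it.

```python
import time
from collections import deque
from typing import Deque, Optional, Tuple


class CostBudget:
    """Hypothetical stand-in illustrating per-request and per-hour cost limits."""

    def __init__(self, max_cost_per_request: float, max_cost_per_hour: float,
                 alert_threshold: float = 0.80) -> None:
        self.max_cost_per_request = max_cost_per_request
        self.max_cost_per_hour = max_cost_per_hour
        self.alert_threshold = alert_threshold
        # (timestamp, cost) pairs, appended in time order.
        self._events: Deque[Tuple[float, float]] = deque()

    def hourly_cost(self, now: Optional[float] = None) -> float:
        now = time.time() if now is None else now
        # Drop entries older than one hour (3600 s) before summing.
        while self._events and now - self._events[0][0] >= 3600:
            self._events.popleft()
        return sum(cost for _, cost in self._events)

    def allow(self, estimated_cost: float, now: Optional[float] = None) -> bool:
        # Refuse a request that would break either limit.
        if estimated_cost > self.max_cost_per_request:
            return False
        return self.hourly_cost(now) + estimated_cost <= self.max_cost_per_hour

    def record(self, cost: float, now: Optional[float] = None) -> None:
        self._events.append((time.time() if now is None else now, cost))

    def should_alert(self, now: Optional[float] = None) -> bool:
        # Assumed meaning: alert once spend reaches alert_threshold x hourly limit.
        return self.hourly_cost(now) >= self.alert_threshold * self.max_cost_per_hour


budget = CostBudget(max_cost_per_request=0.10, max_cost_per_hour=5.00, alert_threshold=0.80)
if budget.allow(0.04):
    budget.record(0.04)  # record the actual cost once the LLM call returns
print(budget.hourly_cost(), budget.should_alert())
```

A rolling window avoids the spike that a fixed clock-hour bucket allows at the boundary between hours, at the cost of storing one entry per request.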

MiddlewareContext

Description: Context object passed through middleware chain

Fields:

  • user_id (str, optional): User identifier
  • session_id (str, optional): Session identifier
  • request_id (str, optional): Request identifier
  • timestamp (datetime): Request timestamp
  • metadata (Dict): Additional metadata
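For experimenting without the package installed, these fields map naturally onto a Python dataclass. The stand-in below is hypothetical (`ContextSketch` is not the library class). It defaults `timestamp` to the current UTC time and implements `add_metadata`/`get_metadata` as described in the API reference, assuming a missing key returns `None` rather than raising.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, Optional


@dataclass
class ContextSketch:
    """Hypothetical stand-in mirroring MiddlewareContext's documented fields."""
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    request_id: Optional[str] = None
    # Timezone-aware UTC timestamp captured when the context is created.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # default_factory gives each context its own dict instead of a shared one.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_metadata(self, key: str, value: Any) -> None:
        self.metadata[key] = value

    def get_metadata(self, key: str) -> Any:
        # Assumption: unknown keys yield None rather than raising KeyError.
        return self.metadata.get(key)


ctx = ContextSketch(user_id="user_123", request_id="req_789")
ctx.add_metadata("user_tier", "premium")
print(ctx.get_metadata("user_tier"))  # premium
```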

Usage Examples

Basic Middleware Setup

from recoagent.agents.middleware import GuardrailsMiddleware, CostTrackingMiddleware
from recoagent.agents import RAGAgentGraph

# Create middleware components
guardrails = GuardrailsMiddleware()
cost_tracker = CostTrackingMiddleware(max_cost_per_request=0.05)

# Create agent with middleware (config, llm, and tools are assumed to be defined elsewhere)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[guardrails, cost_tracker]
)

# Middleware runs automatically on every invoke() call
result = agent.invoke({"query": "What is machine learning?"})
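The order of the `middleware` list matters. A common convention, assumed here rather than confirmed for `RAGAgentGraph`, is the "onion" model: requests pass through middleware in list order and responses pass back in reverse. The sketch below uses a hypothetical `TagMiddleware` class and `run_chain` helper to make that ordering visible; the context argument is omitted for brevity.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Payload = Dict[str, Any]


class TagMiddleware:
    """Hypothetical middleware that records when it sees the request and response."""

    def __init__(self, name: str, trace: List[str]) -> None:
        self.name = name
        self.trace = trace

    async def process_request(self, request: Payload) -> Payload:
        self.trace.append(f"{self.name}:request")
        return request

    async def process_response(self, response: Payload) -> Payload:
        self.trace.append(f"{self.name}:response")
        return response


async def run_chain(middleware: List[TagMiddleware],
                    handler: Callable[[Payload], Awaitable[Payload]],
                    request: Payload) -> Payload:
    # Requests flow front to back through the list...
    for mw in middleware:
        request = await mw.process_request(request)
    response = await handler(request)
    # ...and responses flow back to front (assumed onion ordering).
    for mw in reversed(middleware):
        response = await mw.process_response(response)
    return response


async def fake_agent(request: Payload) -> Payload:
    # Stand-in for the agent: echoes the query back as the answer.
    return {"answer": f"echo: {request['query']}"}


trace: List[str] = []
chain = [TagMiddleware("guardrails", trace), TagMiddleware("cost", trace)]
result = asyncio.run(run_chain(chain, fake_agent, {"query": "What is machine learning?"}))
print(trace)
# ['guardrails:request', 'cost:request', 'cost:response', 'guardrails:response']
```

Under this model, placing guardrails first means unsafe input is rejected before any cost is incurred, and unsafe output is checked last, after every other middleware has touched the response.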

Advanced Safety Middleware

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import GuardrailsMiddleware
from recoagent.agents.policies import PolicyEngine, SafetyPolicy, ToolPolicy

# Create a policy engine combining content and tool policies
policy_engine = PolicyEngine([
    SafetyPolicy(
        blocked_patterns=[
            r"(?i)(harmful|dangerous|illegal)",
            r"(?i)(personal\s+information|credit\s+card)"
        ],
        max_response_length=1000
    ),
    ToolPolicy(
        allowed_tools={"retrieval", "web_search"},
        tool_usage_limits={"web_search": 3}
    )
])

# Create advanced guardrails
guardrails = GuardrailsMiddleware(
    policy_engine=policy_engine,
    enable_input_filtering=True,
    enable_output_filtering=True,
    block_on_violation=True
)

# Use with agent (config, llm, and tools are assumed to be defined elsewhere)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[guardrails]
)
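The `ToolPolicy` settings in this example read as an allowlist plus per-tool call budgets, and `max_response_length` as a cap on answer size. The hypothetical `ToolBudget` class and `within_length` helper below sketch those checks in plain Python. Whether the library counts limits per request or per session, whether unlisted tools are unlimited, and whether long responses are truncated or rejected are all assumptions to verify; here limits are counted per instance and long responses are rejected.

```python
from collections import Counter
from typing import Dict, Set


class ToolBudget:
    """Hypothetical per-request tracker for an allowlist and per-tool call limits."""

    def __init__(self, allowed_tools: Set[str], tool_usage_limits: Dict[str, int]) -> None:
        self.allowed_tools = allowed_tools
        self.tool_usage_limits = tool_usage_limits
        self.calls: Counter = Counter()

    def try_use(self, tool: str) -> bool:
        # Tools outside the allowlist are always refused.
        if tool not in self.allowed_tools:
            return False
        # Tools without an explicit limit are treated as unlimited (assumption).
        limit = self.tool_usage_limits.get(tool)
        if limit is not None and self.calls[tool] >= limit:
            return False
        self.calls[tool] += 1
        return True


def within_length(answer: str, max_response_length: int = 1000) -> bool:
    # Assumption: length is measured in characters, not tokens.
    return len(answer) <= max_response_length


budget = ToolBudget({"retrieval", "web_search"}, {"web_search": 3})
print([budget.try_use("web_search") for _ in range(4)])  # [True, True, True, False]
print(budget.try_use("code_exec"))  # False
```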

Cost Management Middleware

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import CostTrackingMiddleware

# Create cost management middleware
cost_middleware = CostTrackingMiddleware(
    max_cost_per_request=0.10,
    max_cost_per_hour=10.00,
    cost_tracking_enabled=True,
    alert_threshold=0.80
)

# Use with agent (config, llm, and tools are assumed to be defined elsewhere)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[cost_middleware]
)

# Execute queries with cost tracking
result = agent.invoke({"query": "Complex query"})

# Check cost information (amounts in USD)
print(f"Request cost: ${cost_middleware.get_request_cost():.4f}")
print(f"Hourly cost: ${cost_middleware.get_hourly_cost():.4f}")

Custom Middleware Implementation

from typing import Any, Dict

import structlog

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import BaseMiddleware, MiddlewareContext


class CustomLoggingMiddleware(BaseMiddleware):
    """Custom middleware for detailed logging."""

    def __init__(self, log_level: str = "INFO"):
        # Stored for reference; this example does not apply it to the logger
        self.log_level = log_level
        self.logger = structlog.get_logger()

    async def process_request(self, context: MiddlewareContext, request: Dict[str, Any]) -> Dict[str, Any]:
        """Log the incoming request and tag it with custom metadata."""
        self.logger.info(
            "Processing request",
            user_id=context.user_id,
            request_id=context.request_id,
            query=request.get("query", ""),
            timestamp=context.timestamp
        )

        # Add custom metadata
        request["custom_metadata"] = {
            "processed_by": "custom_middleware",
            "timestamp": context.timestamp.isoformat()
        }

        return request

    async def process_response(self, context: MiddlewareContext, response: Dict[str, Any]) -> Dict[str, Any]:
        """Log the outgoing response."""
        self.logger.info(
            "Processing response",
            user_id=context.user_id,
            request_id=context.request_id,
            response_length=len(str(response.get("answer", ""))),
            timestamp=context.timestamp
        )

        return response


# Create custom middleware
logging_middleware = CustomLoggingMiddleware(log_level="DEBUG")

# Use with other middleware (guardrails and cost_tracker come from the earlier examples)
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=[logging_middleware, guardrails, cost_tracker]
)
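To try the `process_request`/`process_response` contract without recoagent installed, the interface can be modelled with an abstract base class. This is a hedged sketch: `MiddlewareBase` stands in for `BaseMiddleware`, `LatencyMiddleware` is a hypothetical example, and the context is reduced to a plain dict. It also illustrates a design point: per-request state belongs in the context rather than on the middleware instance, because one instance may serve many concurrent requests.

```python
import asyncio
import time
from abc import ABC, abstractmethod
from typing import Any, Dict

Payload = Dict[str, Any]


class MiddlewareBase(ABC):
    """Hypothetical stand-in for BaseMiddleware: subclasses must implement both hooks."""

    @abstractmethod
    async def process_request(self, context: Payload, request: Payload) -> Payload:
        ...

    @abstractmethod
    async def process_response(self, context: Payload, response: Payload) -> Payload:
        ...


class LatencyMiddleware(MiddlewareBase):
    """Hypothetical middleware that adds elapsed time to the response."""

    async def process_request(self, context: Payload, request: Payload) -> Payload:
        # Keep per-request state in the context, not on self, so that
        # concurrent requests sharing this instance don't overwrite each other.
        context["started_at"] = time.perf_counter()
        return request

    async def process_response(self, context: Payload, response: Payload) -> Payload:
        elapsed = time.perf_counter() - context["started_at"]
        return {**response, "latency_ms": round(elapsed * 1000, 2)}


async def demo() -> Payload:
    mw = LatencyMiddleware()
    context: Payload = {"request_id": "req_789"}
    await mw.process_request(context, {"query": "What is AI?"})
    await asyncio.sleep(0.01)  # stand-in for the agent call
    return await mw.process_response(context, {"answer": "..."})


print(asyncio.run(demo()))
```

Instantiating `MiddlewareBase` directly raises `TypeError`, so a subclass that forgets one of the hooks fails at construction time rather than mid-request.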

Multi-Middleware Chain

from recoagent.agents import RAGAgentGraph
from recoagent.agents.middleware import (
    GuardrailsMiddleware,
    CostTrackingMiddleware,
    RateLimitMiddleware
)

# Create middleware chain (CustomLoggingMiddleware is defined in the previous example)
middleware_chain = [
    RateLimitMiddleware(requests_per_minute=60),
    GuardrailsMiddleware(),
    CostTrackingMiddleware(max_cost_per_request=0.10),
    CustomLoggingMiddleware()
]

# Create agent with the full middleware chain
agent = RAGAgentGraph(
    config=config,
    llm=llm,
    tools=tools,
    middleware=middleware_chain
)

# Each middleware processes the request in the order listed in middleware_chain
result = agent.invoke({
    "query": "What is artificial intelligence?",
    "user_id": "user_123",
    "session_id": "session_456"
})
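`RateLimitMiddleware(requests_per_minute=60)` is not documented further on this page. A sliding-window limiter is one plausible way such a limit could be enforced; the sketch below is hypothetical (`SlidingWindowLimiter` is not part of recoagent), and keying the window by user ID is an assumption.

```python
import time
from collections import defaultdict, deque
from typing import Deque, Dict, Optional


class SlidingWindowLimiter:
    """Hypothetical limiter: at most `requests_per_minute` calls per key in any 60 s window."""

    def __init__(self, requests_per_minute: int = 60) -> None:
        self.limit = requests_per_minute
        self.window = 60.0
        self._hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, key: str, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        hits = self._hits[key]
        # Forget requests that have slid out of the 60-second window.
        while hits and now - hits[0] >= self.window:
            hits.popleft()
        if len(hits) >= self.limit:
            return False
        hits.append(now)
        return True


limiter = SlidingWindowLimiter(requests_per_minute=2)
print([limiter.allow("user_123", now=t) for t in (0.0, 1.0, 2.0, 61.0)])
# [True, True, False, True]
```

Placing a limiter like this first in the chain rejects excess traffic before guardrail checks or LLM calls consume any resources.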

Middleware Context Management

from datetime import datetime, timezone

from recoagent.agents.middleware import MiddlewareContext

# Create context for request
context = MiddlewareContext(
    user_id="user_123",
    session_id="session_456",
    request_id="req_789",
    timestamp=datetime.now(timezone.utc),  # timezone-aware; datetime.utcnow() is deprecated since Python 3.12
    metadata={
        "source": "web_interface",
        "user_tier": "premium"
    }
)

# Pass context through middleware (agent is a RAGAgentGraph configured as in the earlier examples)
result = agent.invoke_with_context(
    context=context,
    request={"query": "Premium query"}
)

# Access context information
print(f"User: {context.user_id}")
print(f"Session: {context.session_id}")
print(f"Metadata: {context.metadata}")
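In the multi-middleware example the identifiers travel inside the request dict, while here they are set on an explicit context. A hypothetical helper bridging the two, generating a `request_id` when none is supplied, might look like this. `split_request`, the `req_` prefix, the use of `uuid4`, and the split between identifier keys and payload keys are all assumptions for illustration.

```python
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Tuple

# Keys treated as context identifiers rather than request payload (assumption).
ID_KEYS = ("user_id", "session_id", "request_id")


def split_request(raw: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical: separate identifier fields (context) from the payload (request)."""
    context = {k: raw.get(k) for k in ID_KEYS}
    # Generate a unique request id when the caller did not provide one.
    if context["request_id"] is None:
        context["request_id"] = f"req_{uuid.uuid4().hex[:12]}"
    context["timestamp"] = datetime.now(timezone.utc)
    request = {k: v for k, v in raw.items() if k not in ID_KEYS}
    return context, request


ctx, req = split_request({"query": "What is artificial intelligence?",
                          "user_id": "user_123", "session_id": "session_456"})
print(ctx["user_id"], ctx["request_id"], req)
```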

API Reference

BaseMiddleware Methods

process_request(context: MiddlewareContext, request: Dict) -> Dict

Process incoming request

Parameters:

  • context (MiddlewareContext): Request context
  • request (Dict): Request data

Returns: Modified request data

process_response(context: MiddlewareContext, response: Dict) -> Dict

Process outgoing response

Parameters:

  • context (MiddlewareContext): Request context
  • response (Dict): Response data

Returns: Modified response data

GuardrailsMiddleware Methods

check_input_safety(input_data: Dict) -> bool

Check input for safety violations

Parameters:

  • input_data (Dict): Input data to check

Returns: True if safe

check_output_safety(output_data: Dict) -> bool

Check output for safety violations

Parameters:

  • output_data (Dict): Output data to check

Returns: True if safe

CostTrackingMiddleware Methods

get_request_cost() -> float

Get cost for current request

Returns: Cost in USD

get_hourly_cost() -> float

Get total cost for current hour

Returns: Cost in USD

is_cost_limit_exceeded() -> bool

Check if cost limits are exceeded

Returns: True if limits exceeded

reset_hourly_cost() -> None

Reset hourly cost tracking

MiddlewareContext Methods

add_metadata(key: str, value: Any) -> None

Add metadata to context

Parameters:

  • key (str): Metadata key
  • value (Any): Metadata value

get_metadata(key: str) -> Any

Get metadata from context

Parameters:

  • key (str): Metadata key

Returns: Metadata value

See Also