State Management
This document explains the state management system in Atlas, which provides structured state models for LangGraph workflows.
Overview
The state management system in Atlas provides:
- Structured State Models: Pydantic models for representing workflow state
- Type Safety: Type hints and validation for state data
- Message History: Standardized conversation history tracking
- Context Management: Storage for retrieved knowledge and metadata
- Worker Coordination: State patterns for parallel agent workflows
The system is designed to be:
- Consistent: Provide uniform state access patterns
- Extensible: Support custom state attributes
- Typesafe: Leverage Pydantic’s type validation
- Compatible: Integrate seamlessly with LangGraph
Core Components
Base Types and Classes
The state management system starts with basic type definitions:
class Message(TypedDict):
    """Message in the conversation."""
    role: str
    content: str

class Document(TypedDict):
    """Document from the knowledge base."""
    content: str
    metadata: Dict[str, Any]
    relevance_score: float

class Context(TypedDict):
    """Context for the agent."""
    documents: List[Document]
    query: str
These types provide standardized structures for:
- Message: Conversation messages with role and content
- Document: Knowledge documents with content, metadata, and relevance
- Context: Container for retrieved documents and query information
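Because these are TypedDict definitions, instances are ordinary dictionaries at runtime; the field types are checked by static tools rather than when the dictionary is built. The sketch below repeats the three definitions so it runs on its own, and the sample values (source path, score, query) are made up for illustration:

```python
from typing import Any, Dict, List, TypedDict


class Message(TypedDict):
    role: str
    content: str


class Document(TypedDict):
    content: str
    metadata: Dict[str, Any]
    relevance_score: float


class Context(TypedDict):
    documents: List[Document]
    query: str


# Illustrative values only; the source path and score are invented.
doc: Document = {
    "content": "Atlas stores workflow state in Pydantic models.",
    "metadata": {"source": "docs/state.md"},
    "relevance_score": 0.9,
}
context: Context = {"documents": [doc], "query": "How is state stored?"}
message: Message = {"role": "user", "content": context["query"]}

# TypedDict instances are plain dicts at runtime.
print(type(context).__name__)
```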
Worker Configuration
For specialized agents, the system defines the WorkerConfig class:
class WorkerConfig(BaseModel):
    """Configuration for a worker agent."""
    worker_id: str = Field(description="Unique identifier for the worker")
    specialization: str = Field(description="What this worker specializes in")
    system_prompt: str = Field(description="System prompt for this worker")
This enables standardized configuration of worker agents with:
- Identification: Unique worker IDs
- Specialization: Worker roles and capabilities
- Customization: Worker-specific system prompts
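As a rough illustration, a worker configuration might be created as follows. The class is copied from the definition above so the snippet is self-contained; the worker ID and prompt text are hypothetical. Since none of the fields has a default, all three are required:

```python
from pydantic import BaseModel, Field, ValidationError


class WorkerConfig(BaseModel):
    """Configuration for a worker agent (copied from the definition above)."""

    worker_id: str = Field(description="Unique identifier for the worker")
    specialization: str = Field(description="What this worker specializes in")
    system_prompt: str = Field(description="System prompt for this worker")


# Hypothetical worker; the ID and prompt text are illustrative.
retrieval_worker = WorkerConfig(
    worker_id="retrieval_worker",
    specialization="Document retrieval",
    system_prompt="You find passages relevant to the user's question.",
)

# All three fields are required, so omitting any of them raises ValidationError.
try:
    WorkerConfig(worker_id="incomplete_worker")
    missing_fields_rejected = False
except ValidationError:
    missing_fields_rejected = True
```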
AgentState
The primary state model for individual agents is AgentState:
class AgentState(BaseModel):
    """State for a LangGraph agent."""
    # Basic state
    messages: List[Message] = Field(
        default_factory=list, description="Conversation history"
    )
    context: Optional[Context] = Field(
        default=None, description="Retrieved context information"
    )
    # Worker agent state (for parallel processing)
    worker_id: Optional[str] = Field(
        default=None, description="ID of the current worker (if any)"
    )
    worker_results: Dict[str, Any] = Field(
        default_factory=dict, description="Results from worker agents"
    )
    worker_configs: List[WorkerConfig] = Field(
        default_factory=list, description="Configurations for worker agents"
    )
    # Flags
    process_complete: bool = Field(
        default=False, description="Whether processing is complete"
    )
    error: Optional[str] = Field(default=None, description="Error message if any")
The AgentState class maintains:
- Conversation History: List of user and assistant messages
- Retrieved Knowledge: Contextual information and documents
- Worker Metadata: ID and results for parallel processing
- Status Flags: Processing completion and error states
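The collection fields use default_factory, so each instance gets its own list or dictionary. The following sketch shows this with a trimmed copy of AgentState; to keep it standalone, messages are typed as plain dictionaries rather than Message, which is a simplification and not how the Atlas model is declared:

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class AgentState(BaseModel):
    """Trimmed copy of AgentState with only the fields used below."""

    messages: List[Dict[str, str]] = Field(default_factory=list)
    worker_results: Dict[str, Any] = Field(default_factory=dict)
    process_complete: bool = False
    error: Optional[str] = None


first = AgentState()
second = AgentState()

# default_factory builds a fresh list per instance, so history is not shared.
first.messages.append({"role": "user", "content": "Hello"})

print(len(first.messages), len(second.messages))
```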
ControllerState
For multi-agent orchestration, the system defines the ControllerState class:
class ControllerState(BaseModel):
    """State for a controller agent managing multiple workers."""
    # Main state
    messages: List[Message] = Field(
        default_factory=list, description="Main conversation history"
    )
    context: Optional[Context] = Field(
        default=None, description="Retrieved context information"
    )
    # Worker management
    workers: Dict[str, AgentState] = Field(
        default_factory=dict, description="States for all workers"
    )
    active_workers: List[str] = Field(
        default_factory=list, description="Currently active worker IDs"
    )
    completed_workers: List[str] = Field(
        default_factory=list, description="IDs of workers that have completed"
    )
    # Task tracking
    tasks: List[Dict[str, Any]] = Field(
        default_factory=list, description="Tasks to be processed"
    )
    results: List[Dict[str, Any]] = Field(
        default_factory=list, description="Results from completed tasks"
    )
    # Flags
    all_tasks_assigned: bool = Field(
        default=False, description="Whether all tasks have been assigned"
    )
    all_tasks_completed: bool = Field(
        default=False, description="Whether all tasks have been completed"
    )
The ControllerState class manages:
- Global Conversation: User-facing conversation history
- Worker Registry: Tracking multiple worker agents
- Task Management: Distribution and collection of tasks
- Completion Status: Assignment and completion flags
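One way the worker registry and completion flags could be kept in sync is sketched below. The mark_worker_complete helper is hypothetical and not part of Atlas, and the model is a trimmed copy of ControllerState containing only the tracking fields:

```python
from typing import List

from pydantic import BaseModel, Field


class ControllerState(BaseModel):
    """Trimmed copy of ControllerState with only the tracking fields."""

    active_workers: List[str] = Field(default_factory=list)
    completed_workers: List[str] = Field(default_factory=list)
    all_tasks_assigned: bool = False
    all_tasks_completed: bool = False


def mark_worker_complete(state: ControllerState, worker_id: str) -> ControllerState:
    """Hypothetical helper: move a worker from active to completed."""
    if worker_id in state.active_workers:
        state.active_workers.remove(worker_id)
        state.completed_workers.append(worker_id)
    # Done only when every task was handed out and no worker is still running.
    state.all_tasks_completed = state.all_tasks_assigned and not state.active_workers
    return state


state = ControllerState(active_workers=["worker_a", "worker_b"], all_tasks_assigned=True)
mark_worker_complete(state, "worker_a")
after_first = state.all_tasks_completed
mark_worker_complete(state, "worker_b")
```

Deriving all_tasks_completed from the two lists, rather than setting it by hand, avoids the flag drifting out of step with the registry.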
Integration with LangGraph
State Graph Initialization
The state models integrate with LangGraph’s StateGraph:
from langgraph.graph import StateGraph
from atlas.graph.state import AgentState, ControllerState
# Create a graph with AgentState
basic_graph = StateGraph(AgentState)
# Create a graph with ControllerState for multi-agent workflows
controller_graph = StateGraph(ControllerState)
This provides:
- Type Safety: Graph nodes work with properly typed state
- Validation: State transitions validate against the model
- Documentation: State fields are self-documenting via descriptions
Node Functions
Node functions in LangGraph receive and return the state:
def retrieve_knowledge(state: AgentState, config: Optional[AtlasConfig] = None) -> AgentState:
    """Retrieve knowledge from the Atlas knowledge base."""
    # Fall back to a default configuration when none is passed in
    # (assumes AtlasConfig can be constructed without arguments)
    cfg = config or AtlasConfig()
    # Initialize knowledge base
    kb = KnowledgeBase(collection_name=cfg.collection_name, db_path=cfg.db_path)
    # Extract query from state
    # ...
    # Update state with retrieved documents
    state.context = {"documents": documents, "query": query}
    return state
Conditional Edges
State fields are used for graph routing decisions:
# Add conditional edge based on state
builder.add_conditional_edges(
    "generate_response",
    should_end,  # Function that examines state.process_complete
    {True: END, False: "retrieve_knowledge"}
)
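The should_end function itself is not shown here. A minimal sketch of what such a routing function might look like follows, using a stand-in state class so it runs without Atlas or LangGraph installed. Ending the run when an error is recorded is an assumption of this sketch, not documented Atlas behavior:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubState:
    """Stand-in for AgentState with only the fields the router reads."""

    process_complete: bool = False
    error: Optional[str] = None


def should_end(state: StubState) -> bool:
    """Return True to route to END, False to loop back to retrieval."""
    # Assumption: a recorded error also ends the run so the graph cannot loop forever.
    return state.process_complete or state.error is not None
```

The return values line up with the keys of the mapping passed to add_conditional_edges: True selects END and False selects "retrieve_knowledge".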
State Management Patterns
Conversation State Management
The AgentState maintains a conversation history:
# Initialize state with user message
initial_state = AgentState(messages=[{"role": "user", "content": "Hello"}])

# Add assistant response in a node function
def add_response(state: AgentState) -> AgentState:
    # Generate response using LLM
    response = "..."
    # Add to conversation history
    state.messages.append({"role": "assistant", "content": response})
    return state
Context Management
Retrieved knowledge is stored in the context field:
# Store retrieved documents in state context
def store_context(state: AgentState, documents: List[Document]) -> AgentState:
    state.context = {
        "documents": documents,
        "query": "original query"
    }
    return state

# Access context in another node
def use_context(state: AgentState) -> AgentState:
    if state.context and state.context["documents"]:
        documents = state.context["documents"]
        # Use documents...
    return state
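A node that consumes the context will often want only the most relevant documents. The helper below is a hypothetical sketch of that step (top_documents is not an Atlas function); it relies only on the relevance_score key defined on Document:

```python
from typing import Any, Dict, List


def top_documents(context: Dict[str, Any], limit: int = 3) -> List[Dict[str, Any]]:
    """Hypothetical helper: return the highest-scoring documents first."""
    documents = context.get("documents", [])
    ranked = sorted(documents, key=lambda d: d["relevance_score"], reverse=True)
    return ranked[:limit]


# Illustrative context with made-up contents and scores.
sample_context = {
    "query": "What is state management?",
    "documents": [
        {"content": "low", "metadata": {}, "relevance_score": 0.2},
        {"content": "high", "metadata": {}, "relevance_score": 0.9},
        {"content": "medium", "metadata": {}, "relevance_score": 0.5},
    ],
}
best_two = top_documents(sample_context, limit=2)
```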
Error Handling
State includes error tracking:
# Handle errors in node function
def node_with_error_handling(state: AgentState) -> AgentState:
    try:
        # Potentially risky operation...
        result = api_call()
    except Exception as e:
        # Record error in state
        state.error = f"API error: {str(e)}"
    return state

# Check for errors in conditional edge function
def check_for_errors(state: AgentState) -> bool:
    return state.error is not None
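When many nodes need the same try/except, the pattern can be factored into a decorator. This is a sketch of one possible approach, not something Atlas provides: capture_errors and flaky_node are hypothetical names, and StubState stands in for AgentState:

```python
import functools
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class StubState:
    """Stand-in for AgentState with only the error field."""

    error: Optional[str] = None


def capture_errors(node: Callable[[StubState], StubState]) -> Callable[[StubState], StubState]:
    """Hypothetical decorator: record any exception in state.error."""

    @functools.wraps(node)
    def wrapper(state: StubState) -> StubState:
        try:
            return node(state)
        except Exception as exc:
            # Keep the node name so the failing step can be identified later.
            state.error = f"{node.__name__} failed: {exc}"
            return state

    return wrapper


@capture_errors
def flaky_node(state: StubState) -> StubState:
    raise RuntimeError("service unavailable")


result = flaky_node(StubState())
```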
Worker Orchestration
In multi-agent workflows, the ControllerState manages worker coordination:
# Create tasks for workers
def create_worker_tasks(state: ControllerState) -> ControllerState:
    # Create tasks based on user query
    # ...
    # Add tasks to state
    state.tasks = tasks
    # Initialize worker states
    for task in tasks:
        worker_id = task["worker_id"]
        worker_state = AgentState(
            worker_id=worker_id,
            messages=[{"role": "user", "content": query}]
        )
        state.workers[worker_id] = worker_state
        state.active_workers.append(worker_id)
    state.all_tasks_assigned = True
    return state

# Process results from workers
def process_worker_results(state: ControllerState) -> ControllerState:
    combined_results = []
    for worker_id in state.completed_workers:
        worker_state = state.workers.get(worker_id)
        if worker_state:
            # Extract results from worker state
            # ...
            combined_results.append({"worker_id": worker_id, "content": result})
    state.results = combined_results
    state.all_tasks_completed = True
    return state
Usage Examples
Basic Agent State
from atlas.graph.state import AgentState
from atlas.graph.workflows import create_basic_rag_graph
# Create initial state
initial_state = AgentState(
    messages=[{"role": "user", "content": "What is the trimodal methodology?"}]
)
# Create and run graph
graph = create_basic_rag_graph()
final_state = graph.invoke(initial_state)
# Extract assistant's response from final state
assistant_response = final_state.messages[-1]["content"]
print(f"Response: {assistant_response}")
Multi-Agent Controller State
from atlas.graph.state import ControllerState
from atlas.graph.workflows import create_controller_worker_graph
# Create initial state
initial_state = ControllerState(
    messages=[{"role": "user", "content": "Explain knowledge graphs in Atlas"}]
)
# Create and run graph
graph = create_controller_worker_graph()
final_state = graph.invoke(initial_state)
# Access worker results (each entry is a dict with "worker_id" and "content")
print(f"Completed workers: {final_state.completed_workers}")
for result in final_state.results:
    print(f"Results from {result['worker_id']}: {result['content'][:100]}...")
# Get final response
final_response = final_state.messages[-1]["content"]
print(f"Final response: {final_response}")
Creating Custom State
You can extend the base state models for specialized workflows:
from typing import List, Optional

from pydantic import BaseModel, Field
from atlas.graph.state import AgentState

class CustomDocumentState(BaseModel):
    document_id: str
    processed: bool = False
    summary: Optional[str] = None

class DocumentProcessingState(AgentState):
    """Extended state for document processing workflows."""
    documents_to_process: List[CustomDocumentState] = Field(
        default_factory=list,
        description="Documents pending processing"
    )
    processed_documents: List[CustomDocumentState] = Field(
        default_factory=list,
        description="Documents that have been processed"
    )
    current_document_id: Optional[str] = Field(
        default=None,
        description="ID of the document currently being processed"
    )

# Use in a LangGraph workflow
from langgraph.graph import StateGraph
graph = StateGraph(DocumentProcessingState)
# Add nodes and edges...
Advanced Patterns
State Versioning
For long-running workflows, state versioning can be implemented:
# Add version tracking to state
class VersionedAgentState(AgentState):
    version: int = Field(default=1, description="State schema version")
    state_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="History of state transitions"
    )

    def snapshot(self) -> None:
        """Create a snapshot of the current state."""
        # dict() is the Pydantic v1 API; on Pydantic v2 use self.model_dump()
        snapshot = self.dict()
        del snapshot["state_history"]  # Don't include history in history
        self.state_history.append(snapshot)
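The sketch below shows how snapshots accumulate over a few transitions. VersionedState is a trimmed stand-in for VersionedAgentState with messages typed as plain dictionaries, and the hasattr check is an addition of this sketch so it runs on either Pydantic v1 or v2:

```python
from typing import Any, Dict, List

from pydantic import BaseModel, Field


class VersionedState(BaseModel):
    """Trimmed stand-in for VersionedAgentState."""

    messages: List[Dict[str, str]] = Field(default_factory=list)
    state_history: List[Dict[str, Any]] = Field(default_factory=list)

    def snapshot(self) -> None:
        """Append a copy of the current fields, excluding the history itself."""
        # model_dump() is the Pydantic v2 name; dict() is the v1 name.
        dump = self.model_dump() if hasattr(self, "model_dump") else self.dict()
        del dump["state_history"]
        self.state_history.append(dump)


state = VersionedState()
state.snapshot()  # Snapshot of the empty state
state.messages.append({"role": "user", "content": "Hello"})
state.snapshot()  # Snapshot after the first message
```

Each snapshot is a separate copy of the fields, so later changes to state.messages do not alter earlier history entries. Note that the history grows with every call, which matters for the pruning advice under Performance Considerations.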
State Validation
Use Pydantic validation for custom state constraints:
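from pydantic import root_validator

class ValidatedAgentState(AgentState):
    max_history_length: int = Field(default=50, description="Maximum history length")

    # A field validator on "messages" cannot read max_history_length: Pydantic v1
    # validates fields in definition order, and the inherited messages field comes
    # first. A root validator runs once all fields are available.
    @root_validator(skip_on_failure=True)
    def validate_messages_length(cls, values):
        max_length = values.get("max_history_length", 50)
        messages = values.get("messages", [])
        if len(messages) > max_length:
            # Truncate to the most recent messages
            values["messages"] = messages[-max_length:]
        return values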
State Transformations
Create utility functions for common state transformations:
def add_user_message(state: AgentState, message: str) -> AgentState:
    """Add a user message to the state."""
    state.messages.append({"role": "user", "content": message})
    return state

def add_assistant_message(state: AgentState, message: str) -> AgentState:
    """Add an assistant message to the state."""
    state.messages.append({"role": "assistant", "content": message})
    return state

def clear_context(state: AgentState) -> AgentState:
    """Clear the context from the state."""
    state.context = None
    return state
Best Practices
State Design Principles
- Single Source of Truth: Keep all related data in one state object
- Immutability: Treat state as immutable where practical, and return a new or updated state from each node
- Minimal State: Include only necessary information in the state
- Type Safety: Use type hints and validation for all state fields
- Self-Documentation: Include clear field descriptions
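To illustrate the immutability principle, a node can build a new state rather than appending to the one it received. The with_message helper below is hypothetical, the model is a trimmed copy of AgentState with messages typed as plain dictionaries, and the hasattr check exists only so the sketch runs on both Pydantic v1 and v2:

```python
from typing import Dict, List

from pydantic import BaseModel, Field


class AgentState(BaseModel):
    """Trimmed copy of AgentState with only the message history."""

    messages: List[Dict[str, str]] = Field(default_factory=list)


def with_message(state: AgentState, role: str, content: str) -> AgentState:
    """Hypothetical helper: return a new state instead of mutating the input."""
    new_messages = state.messages + [{"role": role, "content": content}]
    # model_copy() is the Pydantic v2 name; copy() is the v1 name.
    copy_fn = state.model_copy if hasattr(state, "model_copy") else state.copy
    return copy_fn(update={"messages": new_messages})


original = AgentState()
updated = with_message(original, "user", "Hello")
```

The input state is left untouched, which makes it easier to compare states before and after a node runs. Values passed through update are not re-validated by Pydantic, so this pattern suits data that has already been checked.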
Performance Considerations
For large state objects:
- Selective Updates: Only update fields that have changed
- Pruning: Remove unnecessary data from the state
- Lazy Loading: Defer loading large data until needed
- Serialization: Consider serialization efficiency for large objects
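Pruning might look like the following for the message history. The prune_messages helper and its keep_last parameter are hypothetical; keeping system messages while trimming older turns is one reasonable policy among several:

```python
from typing import Dict, List


def prune_messages(messages: List[Dict[str, str]], keep_last: int = 20) -> List[Dict[str, str]]:
    """Hypothetical helper: keep system messages plus the most recent turns."""
    if keep_last <= 0:
        raise ValueError("keep_last must be positive")
    system = [m for m in messages if m["role"] == "system"]
    turns = [m for m in messages if m["role"] != "system"]
    return system + turns[-keep_last:]


# Illustrative history: one system message followed by 30 user turns.
history = [{"role": "system", "content": "You are Atlas."}] + [
    {"role": "user", "content": f"message {i}"} for i in range(30)
]
pruned = prune_messages(history, keep_last=5)
```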
Error Handling
For robust error handling with state:
- Error Fields: Use dedicated fields for error information
- Validation: Validate state before processing
- Recovery: Include enough information to recover from errors
- Logging: Log state transitions and errors
Related Documentation
- Graph Nodes - Documentation for graph node functions
- Graph Edges - Documentation for conditional edge routing