Writing Thread-Safe Adapters for Snakepit
Guide Version: 1.0 · Date: 2025-10-11 · Snakepit Version: v0.6.0+
Table of Contents
- Overview
- Prerequisites
- Thread Safety Fundamentals
- Three Safety Patterns
- Step-by-Step Tutorial
- Common Pitfalls
- Testing Strategies
- Library Compatibility
- Advanced Topics
- Debugging Guide
- Best Practices
- Examples
Overview
What is a Thread-Safe Adapter?
A thread-safe adapter can handle multiple concurrent requests without data corruption, race conditions, or undefined behavior. This guide teaches you how to write Python adapters that work correctly with Snakepit's multi-threaded worker profile.
Why Thread Safety Matters
```python
# ❌ NOT thread-safe
class UnsafeAdapter:
    def __init__(self):
        self.counter = 0

    def process(self, data):
        self.counter += 1  # RACE CONDITION!
        return {"count": self.counter}

# ✅ Thread-safe
class SafeAdapter(ThreadSafeAdapter):
    def __init__(self):
        super().__init__()
        self.counter = 0

    def process(self, data):
        with self.acquire_lock():
            self.counter += 1
            return {"count": self.counter}
```

When You Need This Guide
- ✅ Building custom Python adapters for Snakepit
- ✅ Using the thread worker profile (`:thread`)
- ✅ Python 3.13+ with free-threading enabled
- ✅ CPU-intensive workloads requiring parallelism
Prerequisites
Required Knowledge
- Python: Functions, classes, decorators
- Concurrency: Basic understanding of threads
- Snakepit: Adapter pattern, worker profiles
Required Software
```bash
# Python 3.13+ with free-threading
python3.13 --version
# => Python 3.13.0

# Verify free-threading support
python3.13 -c "import sys; print(hasattr(sys, '_is_gil_enabled'))"
# => True

# Snakepit v0.6.0+
mix deps | grep snakepit
# => * snakepit 0.6.0
```
Test Environment Setup
```bash
# Create virtual environment
python3.13 -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install grpcio protobuf numpy pytest pytest-xdist

# Install Snakepit Python bridge
pip install -e deps/snakepit/priv/python
```
Thread Safety Fundamentals
The Three Rules of Thread Safety
1. **Immutable shared state is safe**
   - Read-only data can be accessed concurrently
   - Examples: pre-loaded models, config dicts
2. **Mutable shared state requires locking**
   - If data can be modified, protect it with locks
   - Examples: counters, logs, caches
3. **Thread-local storage is safe**
   - Data isolated per thread doesn't need locks
   - Examples: per-thread caches, buffers
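To ground Rule 3 before introducing the adapter helpers, here is plain-Python thread-local storage with `threading.local()` (a standalone sketch, independent of Snakepit):

```python
import threading

_local = threading.local()  # each thread sees its own attributes

def get_buffer():
    # First access on a given thread creates that thread's buffer
    if not hasattr(_local, "buffer"):
        _local.buffer = []
    return _local.buffer

def worker(n):
    buf = get_buffer()
    buf.append(n)  # no lock needed: this list is never shared
    print(threading.current_thread().name, buf)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads: t.start()
for t in threads: t.join()
```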
Thread Safety Checklist
When reviewing your adapter, ask:
- [ ] Does this method modify shared state?
- [ ] Is this data structure accessed from multiple threads?
- [ ] Does this library release the GIL?
- [ ] Are there any race conditions?
- [ ] Is error handling thread-safe?
Three Safety Patterns
Pattern 1: Shared Read-Only Resources
When to use: Data loaded once, never modified
```python
import torch

from snakepit_bridge.base_adapter_threaded import ThreadSafeAdapter, thread_safe_method

class ModelAdapter(ThreadSafeAdapter):
    __thread_safe__ = True  # Required declaration

    def __init__(self):
        super().__init__()
        # Pattern 1: Shared read-only (NO LOCK NEEDED)
        self.model = self._load_model()
        self.config = {"timeout": 30, "batch_size": 10}

    def _load_model(self):
        """Load the model once; shared across threads"""
        model = torch.load("model.pt")
        model.eval()  # Set to evaluation mode
        return model

    @thread_safe_method
    def predict(self, input_data):
        # Safe: model is read-only
        # PyTorch releases the GIL during the forward pass
        with torch.no_grad():
            output = self.model(torch.tensor(input_data))
        return output.tolist()
```

Why it's safe:
- Model loaded once in `__init__`
- Never modified after loading
- PyTorch `.forward()` releases the GIL
- Multiple threads can read concurrently
Pattern 2: Thread-Local Storage
When to use: Per-thread state (caches, buffers, connections)
```python
class CachingAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        self.model = load_model()  # Shared read-only (Pattern 1)

    @thread_safe_method
    def compute(self, key, data):
        # Pattern 2: Thread-local storage (NO LOCK NEEDED)
        # Each thread has its own cache
        cache = self.get_thread_local('cache', {})
        if key in cache:
            return cache[key]

        # Compute result
        result = self.model.predict(data)

        # Update thread-local cache
        cache[key] = result
        self.set_thread_local('cache', cache)
        return result
```

Why it's safe:
- Each thread has an isolated `cache` dict
- No sharing between threads
- No race conditions possible
- Excellent performance (no locks)

Trade-off: because each thread caches independently, entries are duplicated across threads, so memory use scales with thread count.
Pattern 3: Locked Shared Mutable State
When to use: State that must be shared AND modified
```python
import time

class CountingAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        self.model = load_model()  # Pattern 1

        # Pattern 3: Shared mutable (REQUIRES LOCK)
        self.total_requests = 0
        self.request_log = []

    @thread_safe_method
    def process(self, data):
        # Compute first (NO LOCK - allows parallelism)
        result = self.model.predict(data)

        # THEN lock for the state update (BRIEF LOCK)
        with self.acquire_lock():
            self.total_requests += 1
            self.request_log.append({
                "result": result,
                "timestamp": time.time()
            })
        return result

    @thread_safe_method
    def get_stats(self):
        # Pattern 3: Read shared mutable (REQUIRES LOCK)
        with self.acquire_lock():
            return {
                "total_requests": self.total_requests,
                "log_size": len(self.request_log)
            }
```

Why it's safe:
- Compute happens WITHOUT lock (parallel)
- Lock held only for state update (fast)
- Both reads and writes protected
- No race conditions
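The "compute without the lock" rule is worth seeing in numbers. The sketch below (standalone; `slow_compute` is an illustrative stand-in for an expensive, GIL-releasing operation) times both orderings:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

lock = threading.Lock()
results = []

def slow_compute():
    time.sleep(0.01)  # stand-in for expensive work that releases the GIL
    return 42

def locked_compute():
    with lock:                  # ❌ serializes all workers
        results.append(slow_compute())

def compute_then_lock():
    value = slow_compute()      # ✅ runs in parallel
    with lock:                  # brief lock only for the shared update
        results.append(value)

for fn in (locked_compute, compute_then_lock):
    results.clear()
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=8) as pool:
        for _ in range(40):
            pool.submit(fn)
    print(fn.__name__, f"{time.perf_counter() - start:.2f}s")
```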
Step-by-Step Tutorial
Step 1: Create Thread-Safe Adapter Class
```python
# my_adapter.py
from snakepit_bridge.base_adapter_threaded import (
    ThreadSafeAdapter,
    thread_safe_method,
    tool
)
import numpy as np

class MyAdapter(ThreadSafeAdapter):
    """Example thread-safe adapter"""

    # Step 1.1: Declare thread safety
    __thread_safe__ = True

    def __init__(self):
        # Step 1.2: Call the parent constructor
        super().__init__()

        # Step 1.3: Initialize resources
        # Pattern 1: Shared read-only
        self.model = self._load_model()

        # Pattern 3: Shared mutable
        self.request_count = 0
```

Step 2: Implement Thread-Safe Methods
```python
    # (inside MyAdapter)
    @thread_safe_method
    @tool(description="Compute with NumPy")
    def compute(self, data: list) -> dict:
        """
        Thread-safe computation.
        Uses Pattern 1 (shared model) + Pattern 3 (shared counter).
        """
        # Convert to a NumPy array (thread-safe)
        arr = np.array(data)

        # NumPy computation (releases the GIL - parallel!)
        result = np.dot(arr, arr.T)

        # Update shared state (lock required)
        with self.acquire_lock():
            self.request_count += 1
            count = self.request_count

        return {
            "result": result.tolist(),
            "request_number": count
        }
```

Step 3: Add Thread-Local Caching
```python
    # (inside MyAdapter)
    @thread_safe_method
    @tool(description="Compute with caching")
    def compute_cached(self, key: str, data: list) -> dict:
        """
        Thread-safe computation with a per-thread cache.
        Uses Pattern 2 (thread-local storage).
        """
        # Check the thread-local cache first
        cache = self.get_thread_local('cache', {})
        if key in cache:
            return {
                "result": cache[key],
                "cached": True
            }

        # Compute
        arr = np.array(data)
        result = np.dot(arr, arr.T).tolist()

        # Update the thread-local cache
        cache[key] = result
        self.set_thread_local('cache', cache)

        # Update the shared counter
        with self.acquire_lock():
            self.request_count += 1

        return {
            "result": result,
            "cached": False
        }
```

Step 4: Add Statistics Method
```python
    # (inside MyAdapter)
    @thread_safe_method
    @tool(description="Get adapter statistics")
    def get_stats(self) -> dict:
        """
        Thread-safe statistics.
        Reads shared mutable state (Pattern 3).
        """
        with self.acquire_lock():
            stats = self.get_stats_dict()
            stats['total_requests'] = self.request_count
            return stats
```

Step 5: Test Thread Safety
```python
# test_my_adapter.py
import pytest
from concurrent.futures import ThreadPoolExecutor

from my_adapter import MyAdapter

@pytest.fixture
def my_adapter():
    return MyAdapter()

def test_concurrent_compute(my_adapter):
    """Test concurrent access to the compute method"""
    results = []
    errors = []

    def make_request(i):
        try:
            result = my_adapter.compute([1, 2, 3, 4, 5])
            results.append(result)
        except Exception as e:
            errors.append(e)

    # Hammer with 100 concurrent requests
    with ThreadPoolExecutor(max_workers=20) as executor:
        futures = [executor.submit(make_request, i) for i in range(100)]
        for future in futures:
            future.result(timeout=10)

    # All should succeed
    assert len(results) == 100
    assert len(errors) == 0

    # Request count should be exactly 100
    stats = my_adapter.get_stats()
    assert stats['total_requests'] == 100
```

Common Pitfalls
Pitfall 1: Forgetting to Lock Shared State
```python
# ❌ WRONG: Race condition
class BadAdapter(ThreadSafeAdapter):
    def __init__(self):
        super().__init__()
        self.counter = 0

    @thread_safe_method
    def increment(self):
        self.counter += 1  # NOT ATOMIC!
        return self.counter

# ✅ CORRECT: Lock protected
class GoodAdapter(ThreadSafeAdapter):
    def __init__(self):
        super().__init__()
        self.counter = 0

    @thread_safe_method
    def increment(self):
        with self.acquire_lock():
            self.counter += 1
            return self.counter
```

Problem: `self.counter += 1` is three operations:
1. Read `self.counter`
2. Add 1
3. Write the result back

Between steps 1 and 3, another thread can modify the counter.
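You can watch the lost-update problem happen with a throwaway script (standalone, no Snakepit required); on most runs the unlocked total falls short of 800000:

```python
import threading

counter = 0
lock = threading.Lock()

def unsafe_worker():
    global counter
    for _ in range(100_000):
        counter += 1        # read-modify-write: updates can be lost

def safe_worker():
    global counter
    for _ in range(100_000):
        with lock:
            counter += 1    # the whole read-modify-write is atomic

for worker in (unsafe_worker, safe_worker):
    counter = 0
    threads = [threading.Thread(target=worker) for _ in range(8)]
    for t in threads: t.start()
    for t in threads: t.join()
    print(worker.__name__, counter)  # safe_worker always prints 800000
```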
Pitfall 2: Holding Lock During Expensive Operations
```python
# ❌ WRONG: Lock held during computation
@thread_safe_method
def process(self, data):
    with self.acquire_lock():
        arr = np.array(data)
        result = np.dot(arr, arr.T)  # EXPENSIVE!
        self.results.append(result)
    return result

# ✅ CORRECT: Minimize lock duration
@thread_safe_method
def process(self, data):
    # Compute WITHOUT the lock
    arr = np.array(data)
    result = np.dot(arr, arr.T)

    # THEN lock briefly
    with self.acquire_lock():
        self.results.append(result)
    return result
```

Rule: Only hold locks for the minimum time needed.
Pitfall 3: Using Thread-Unsafe Libraries
```python
# ❌ WRONG: Pandas is NOT thread-safe
import pandas as pd

class BadAdapter(ThreadSafeAdapter):
    def __init__(self):
        super().__init__()
        self.df = pd.DataFrame()  # Shared DataFrame

    @thread_safe_method
    def add_row(self, data):
        # RACE CONDITION! Unlocked read-modify-write on self.df
        # (and DataFrame.append was removed in pandas 2.0)
        self.df = self.df.append(data, ignore_index=True)

# ✅ CORRECT: Use thread-safe alternatives
import polars as pl

class GoodAdapter(ThreadSafeAdapter):
    def __init__(self):
        super().__init__()
        self.rows = []  # Collect rows

    @thread_safe_method
    def add_row(self, data):
        with self.acquire_lock():
            self.rows.append(data)

    @thread_safe_method
    def get_dataframe(self):
        with self.acquire_lock():
            return pl.DataFrame(self.rows)
```

Solution: Use Polars instead of Pandas, or lock ALL DataFrame operations.
Pitfall 4: Missing __thread_safe__ Declaration
```python
# ❌ WRONG: No declaration
class BadAdapter(ThreadSafeAdapter):
    # Missing __thread_safe__ = True
    ...

# ✅ CORRECT: Always declare
class GoodAdapter(ThreadSafeAdapter):
    __thread_safe__ = True
```

Why: The runtime checker validates thread safety only when it is declared.
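A cheap guard against the declaration silently disappearing in a refactor is a one-line test (a hypothetical test of your own, not a Snakepit API):

```python
def test_declares_thread_safety():
    # Fails fast if the declaration is removed or mistyped
    assert getattr(MyAdapter, "__thread_safe__", False) is True
```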
Pitfall 5: Deadlocks
```python
# ❌ WRONG: would deadlock if the adapter used a plain (non-reentrant) Lock
@thread_safe_method
def method_a(self):
    with self.acquire_lock():
        return self.method_b()  # Tries to acquire the same lock!

@thread_safe_method
def method_b(self):
    with self.acquire_lock():
        return "result"

# ✅ CORRECT: ThreadSafeAdapter uses an RLock (reentrant), so the above
# happens to work, but it is cleaner to avoid re-acquiring entirely:
@thread_safe_method
def method_a(self):
    with self.acquire_lock():
        return self._method_b_impl()

def _method_b_impl(self):
    # Private method, called while the lock is held
    return "result"
```

Testing Strategies
Strategy 1: Concurrent Hammer Test
```python
import threading

def test_concurrent_hammer():
    adapter = MyAdapter()
    results = []

    def worker(i):
        for _ in range(100):
            result = adapter.compute([i, i+1, i+2])
            results.append(result)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Should have exactly 1000 results
    assert len(results) == 1000

    # Counter should be exactly 1000
    stats = adapter.get_stats()
    assert stats['total_requests'] == 1000
```

Strategy 2: Race Condition Detector
```python
def test_race_condition():
    """
    If increment has a race condition, the final count will be < 10000
    """
    adapter = MyAdapter()

    def increment_many():
        for _ in range(1000):
            adapter.increment()

    threads = [threading.Thread(target=increment_many) for _ in range(10)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # If thread-safe, the count is exactly 10000
    assert adapter.get_count() == 10000
```

Strategy 3: Thread Safety Checker
```python
from snakepit_bridge.thread_safety_checker import ThreadSafetyChecker

def test_with_checker():
    checker = ThreadSafetyChecker(enabled=True, strict_mode=True)
    adapter = MyAdapter()

    # Run concurrent requests
    def worker():
        for _ in range(50):
            adapter.compute([1, 2, 3])

    threads = [threading.Thread(target=worker) for _ in range(20)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Get the thread safety report
    report = checker.get_report()

    # Should have no warnings
    assert len(report['warnings']) == 0
```

Strategy 4: Load Testing
```python
import time
from concurrent.futures import ThreadPoolExecutor

import pytest

@pytest.mark.benchmark
def test_throughput():
    adapter = MyAdapter()

    def single_request():
        return adapter.compute([1, 2, 3, 4, 5])

    # Measure throughput
    with ThreadPoolExecutor(max_workers=16) as executor:
        start = time.time()
        futures = [executor.submit(single_request) for _ in range(1000)]
        results = [f.result() for f in futures]
        elapsed = time.time() - start

    throughput = 1000 / elapsed
    print(f"Throughput: {throughput:.2f} req/s")

    # All requests should succeed
    assert len(results) == 1000
```

Library Compatibility
Thread-Safe Libraries ✅
These libraries work well with threaded adapters:
| Library | Thread-Safe | Notes |
|---|---|---|
| NumPy | ✅ Yes | Releases GIL during computation |
| SciPy | ✅ Yes | Releases GIL for numerical ops |
| PyTorch | ✅ Yes | Configure with torch.set_num_threads() |
| TensorFlow | ✅ Yes | Use tf.config.threading |
| Scikit-learn | ✅ Yes | Set n_jobs=1 per estimator |
| Polars | ✅ Yes | Thread-safe DataFrame library |
| HTTPx | ✅ Yes | Async-first, thread-safe |
| Requests | ✅ Yes | Use separate Session per thread |
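For the Requests row, "separate Session per thread" can be implemented with plain thread-local storage (a sketch; `fetch` and the lazy `_session` helper are illustrative names, not Snakepit APIs):

```python
import threading

import requests

_local = threading.local()

def _session() -> requests.Session:
    # Lazily create one Session per thread; Sessions are not safe to share
    if not hasattr(_local, "session"):
        _local.session = requests.Session()
    return _local.session

def fetch(url: str) -> int:
    return _session().get(url, timeout=10).status_code
```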
Thread-Unsafe Libraries ❌
These require special handling:
| Library | Thread-Safe | Workaround |
|---|---|---|
| Pandas | ❌ No | Use Polars or lock all DataFrame ops |
| Matplotlib | ❌ No | Use threading.local() for figures |
| SQLite3 | ❌ No | One connection per thread, or share one connection behind a lock (`check_same_thread=False`) |
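The same thread-local pattern covers the SQLite3 row: give each thread its own connection so nothing is shared (a sketch; the database path and `record` helper are illustrative):

```python
import sqlite3
import threading

_local = threading.local()

def connection() -> sqlite3.Connection:
    # One connection per thread avoids sharing entirely,
    # so check_same_thread can stay at its safe default (True)
    if not hasattr(_local, "conn"):
        _local.conn = sqlite3.connect("adapter.db")
    return _local.conn

def record(event: str) -> None:
    conn = connection()
    conn.execute("INSERT INTO events(name) VALUES (?)", (event,))
    conn.commit()
```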
Example: Thread-Safe NumPy
```python
import numpy as np

class NumPyAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    @thread_safe_method
    def matrix_multiply(self, a_data, b_data):
        # NumPy releases the GIL - true parallelism!
        a = np.array(a_data)
        b = np.array(b_data)
        result = np.dot(a, b)
        return result.tolist()
```

Example: Thread-Safe PyTorch
```python
import torch

class TorchAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        # Load the model once (shared read-only)
        self.model = torch.load("model.pt")
        self.model.eval()

        # Configure threading
        torch.set_num_threads(4)

    @thread_safe_method
    def inference(self, input_data):
        # PyTorch releases the GIL during forward
        with torch.no_grad():
            tensor = torch.tensor(input_data)
            output = self.model(tensor)
        return output.tolist()
```

Example: Workaround for Pandas
```python
import pandas as pd

class PandasAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    @thread_safe_method
    def process_dataframe(self, data):
        # Lock ALL Pandas operations
        with self.acquire_lock():
            df = pd.DataFrame(data)
            result = df.groupby('category').sum()
            return result.to_dict()
```

Advanced Topics
Topic 1: Custom Locks
```python
import threading

class MultiLockAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        # Use separate locks for independent resources
        self.counter_lock = threading.Lock()
        self.log_lock = threading.Lock()
        self.counter = 0
        self.log = []

    @thread_safe_method
    def increment(self):
        with self.counter_lock:  # Only locks the counter
            self.counter += 1

    @thread_safe_method
    def log_event(self, event):
        with self.log_lock:  # Only locks the log
            self.log.append(event)
```

When to use: Reduce contention by using separate locks for independent resources.
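One caveat when splitting locks: if any method ever needs both, acquire them in one fixed order everywhere, or two threads can deadlock by each holding the lock the other wants. A sketch of the convention, assuming the `MultiLockAdapter` fields above:

```python
    # Convention: counter_lock is ALWAYS acquired before log_lock
    @thread_safe_method
    def snapshot(self):
        with self.counter_lock:
            with self.log_lock:
                return {"count": self.counter, "events": len(self.log)}
```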
Topic 2: Lock-Free Data Structures
```python
from queue import Empty, Queue  # Queue is thread-safe (internally synchronized)

class QueueAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        self.results = Queue()  # No adapter-level lock needed

    @thread_safe_method
    def add_result(self, result):
        self.results.put(result)  # Thread-safe; Queue synchronizes internally

    @thread_safe_method
    def get_results(self):
        results = []
        # Drain with get_nowait() to avoid the empty()/get() race
        try:
            while True:
                results.append(self.results.get_nowait())
        except Empty:
            pass
        return results
```

Topic 3: Atomic Operations
```python
import threading

class AtomicAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        self.counter = 0
        self._lock = threading.Lock()

    @thread_safe_method
    def atomic_increment(self):
        # Explicit acquire/release; equivalent to `with self._lock:`
        self._lock.acquire()
        try:
            self.counter += 1
            result = self.counter
        finally:
            self._lock.release()
        return result
```

Debugging Guide
Enable Thread Safety Checks
```bash
python grpc_server_threaded.py \
  --thread-safety-check  # Enable runtime validation
```
Common Error Messages
Error: "Method accessed by multiple threads"
```
⚠️ THREAD SAFETY: Method 'predict' accessed by 5 different threads concurrently.
```

Cause: Method modifies shared state without locking.
Solution: Add lock around shared state access.
Error: "Unsafe library detected"
```
⚠️ THREAD SAFETY: Unsafe library 'pandas' detected
```

Cause: Using a thread-unsafe library.
Solution: Switch to thread-safe alternative or add locking.
Error: "Adapter does not declare thread safety"
```
⚠️ Adapter MyAdapter does not declare thread safety.
```

Cause: Missing `__thread_safe__ = True`.
Solution: Add declaration to adapter class.
Debugging Tools
```python
# Enable detailed logging
import logging
logging.basicConfig(level=logging.DEBUG)

# Check which thread is running
import threading
print(f"Thread: {threading.current_thread().name}")

# Track lock acquisitions
class DebugAdapter(ThreadSafeAdapter):
    @thread_safe_method
    def compute(self, data):
        name = threading.current_thread().name
        print(f"[{name}] Acquiring lock...")
        with self.acquire_lock():
            print(f"[{name}] Lock acquired!")
            result = do_work(data)
        print(f"[{name}] Lock released")
        return result
```

Best Practices
Do's ✅

1. Always declare thread safety
   ```python
   class MyAdapter(ThreadSafeAdapter):
       __thread_safe__ = True
   ```
2. Use the `@thread_safe_method` decorator
   ```python
   @thread_safe_method
   def my_method(self): ...
   ```
3. Minimize lock duration
   ```python
   # Compute first
   result = expensive_operation()
   # THEN lock
   with self.acquire_lock():
       self.results.append(result)
   ```
4. Use thread-local storage for caches
   ```python
   cache = self.get_thread_local('cache', {})
   ```
5. Test with concurrent load
   ```python
   ThreadPoolExecutor(max_workers=20)
   ```

Don'ts ❌

1. Don't modify shared state without locking
   ```python
   # ❌ WRONG
   self.counter += 1
   ```
2. Don't use thread-unsafe libraries carelessly
   ```python
   # ❌ WRONG
   self.df = self.df.append(row)  # Pandas
   ```
3. Don't hold locks during I/O
   ```python
   # ❌ WRONG
   with self.acquire_lock():
       requests.get(url)  # Blocks other threads!
   ```
4. Don't nest locks (unless reentrant)
   ```python
   # ⚠️ CAREFUL
   with lock_a:
       with lock_b:  # Potential deadlock
           ...
   ```
5. Don't skip testing: no concurrent tests means hidden race conditions.
Examples
Example 1: Simple Counter Adapter
```python
from snakepit_bridge.base_adapter_threaded import ThreadSafeAdapter, thread_safe_method, tool

class CounterAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        self.count = 0

    @thread_safe_method
    @tool(description="Increment counter")
    def increment(self) -> dict:
        with self.acquire_lock():
            self.count += 1
            return {"count": self.count}

    @thread_safe_method
    @tool(description="Get current count")
    def get_count(self) -> dict:
        with self.acquire_lock():
            return {"count": self.count}
```

Example 2: ML Model Adapter with Caching
```python
import torch

from snakepit_bridge.base_adapter_threaded import ThreadSafeAdapter, thread_safe_method, tool

class MLAdapter(ThreadSafeAdapter):
    __thread_safe__ = True

    def __init__(self):
        super().__init__()
        # Pattern 1: Shared read-only
        self.model = torch.load("model.pt")
        self.model.eval()

        # Pattern 3: Shared mutable
        self.total_predictions = 0

    @thread_safe_method
    @tool(description="ML inference with caching")
    def predict(self, input_data: list, cache_key: str = None) -> dict:
        # Pattern 2: Thread-local cache
        if cache_key:
            cache = self.get_thread_local('cache', {})
            if cache_key in cache:
                return {"prediction": cache[cache_key], "cached": True}

        # Compute (NO LOCK - parallel!)
        tensor = torch.tensor(input_data)
        with torch.no_grad():
            output = self.model(tensor)
        prediction = output.tolist()

        # Update the cache (thread-local, no lock needed)
        if cache_key:
            cache[cache_key] = prediction
            self.set_thread_local('cache', cache)

        # Update the counter (shared, lock needed)
        with self.acquire_lock():
            self.total_predictions += 1

        return {"prediction": prediction, "cached": False}

    @thread_safe_method
    @tool(description="Get adapter statistics")
    def get_stats(self) -> dict:
        with self.acquire_lock():
            stats = self.get_stats_dict()
            stats['total_predictions'] = self.total_predictions
            return stats
```

Example 3: Full Production Adapter
See /priv/python/snakepit_bridge/adapters/threaded_showcase.py for a comprehensive 400-line example demonstrating all three safety patterns.
Summary
Key Takeaways
- Three Patterns: Read-only, thread-local, locked mutable
- Minimize Locks: Compute without locks, lock only for updates
- Declare Safety: Always add `__thread_safe__ = True`
- Test Concurrently: Use ThreadPoolExecutor with 20+ workers
- Check Libraries: Use thread-safe libraries (NumPy, PyTorch, Polars)
Checklist for Thread-Safe Adapters
- [ ] Inherits from `ThreadSafeAdapter`
- [ ] Has the `__thread_safe__ = True` declaration
- [ ] Uses `@thread_safe_method` on all public methods
- [ ] Shared mutable state protected with locks
- [ ] Lock duration minimized
- [ ] Thread-local storage used for caches
- [ ] Only thread-safe libraries used (or properly locked)
- [ ] Tested with concurrent requests (100+)
- [ ] Thread safety checker passes
- [ ] No race conditions or deadlocks
Next Steps
- Read: README_THREADING.md
- Study: threaded_showcase.py
- Test: test_thread_safety.py
- Deploy: Production deployment guide (coming soon)
Questions? Open an issue or check the FAQ in Migration Guide.