Handle CORS and selectively cache responses for appropriate methods #1

Merged
prathamesh merged 12 commits from sk-cors-fix into main 2025-08-01 10:37:07 +00:00
12 changed files with 352 additions and 137 deletions

View File

@ -1,4 +1,4 @@
# Provider endpoints and auth
# Provider endpoints and auth (optional)
ALCHEMY_API_KEY=your_key_here
HELIUS_API_KEY=your_key_here
QUICKNODE_ENDPOINT=your_endpoint.quiknode.pro
@ -6,9 +6,9 @@ QUICKNODE_TOKEN=your_token_here
# Proxy settings
PROXY_PORT=8545
CACHE_SIZE_GB=100
CACHE_SIZE_GB=1
BACKOFF_MINUTES=30
# Logging
LOG_LEVEL=INFO
ERROR_DB_PATH=./errors.db
ERROR_DB_PATH=./errors.db

.gitignore (vendored)
View File

@ -1 +1,4 @@
.env
.env
__pycache__
*.db
cache

View File

@ -1 +1,23 @@
The trenches are brutal.
# Solana Proxy
## Setup
- Copy `.env.example` to `.env`:
```bash
cp .env.example .env
```
- The proxy works without any changes to the `.env` file, but you can optionally set API keys for the different providers
## Run
- Start the proxy:
```bash
python3 main.py
```
- This starts the proxy with:
  - RPC endpoint: <http://0.0.0.0:8545>
  - WS endpoint: <ws://0.0.0.0:8545/ws>
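- To sanity-check the setup, send any JSON-RPC request to the proxy. A minimal sketch, assuming the default port above and only the standard library:
```python
# Minimal smoke test against the local proxy (assumes default port 8545).
import json
import urllib.request

payload = json.dumps({"jsonrpc": "2.0", "id": 1, "method": "getHealth"}).encode()
req = urllib.request.Request(
    "http://localhost:8545/",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))  # e.g. {"jsonrpc": "2.0", "result": "ok", "id": 1, ...}
```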

View File

@ -1,11 +1,12 @@
import json
import os
import time
from typing import Dict, Any, Optional
import diskcache
class Cache:
def __init__(self, cache_dir: str = "./cache", size_limit_gb: int = 100):
def __init__(self, cache_dir: str = "./cache", size_limit_gb: int = 1):
self.cache_dir = cache_dir
self.size_limit_bytes = size_limit_gb * 1024 * 1024 * 1024
self.cache = diskcache.Cache(
@ -13,18 +14,41 @@ class Cache:
size_limit=self.size_limit_bytes,
eviction_policy='least-recently-used'
)
def _make_key(self, method: str, params: Dict[str, Any]) -> str:
return f"{method}:{json.dumps(params, sort_keys=True)}"
def get(self, method: str, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
key = self._make_key(method, params)
return self.cache.get(key)
def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any]) -> None:
cached_data = self.cache.get(key)
if cached_data is None:
return None
# Check if cached data has TTL and if it's expired
if isinstance(cached_data, dict) and '_cache_expiry' in cached_data:
if time.time() > cached_data['_cache_expiry']:
# Remove expired entry
self.cache.delete(key)
return None
# Remove cache metadata before returning
response = cached_data.copy()
del response['_cache_expiry']
return response
return cached_data
def set(self, method: str, params: Dict[str, Any], response: Dict[str, Any], ttl: Optional[int] = None) -> None:
key = self._make_key(method, params)
self.cache.set(key, response)
# Add TTL metadata if specified
if ttl is not None:
cached_response = response.copy()
cached_response['_cache_expiry'] = time.time() + ttl
self.cache.set(key, cached_response)
else:
self.cache.set(key, response)
def size_check(self) -> Dict[str, Any]:
stats = self.cache.stats()
return {
@ -33,4 +57,3 @@ class Cache:
"count": stats[0],
"limit_gb": self.size_limit_bytes / (1024 * 1024 * 1024)
}
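The TTL is stored inline with the cached payload and stripped again on read. A minimal sketch of the round-trip, assuming the `Cache` class above and `diskcache` installed:

```python
# Sketch: TTL metadata round-trip with the Cache above.
from cache import Cache

cache = Cache(cache_dir="./cache", size_limit_gb=1)
response = {"jsonrpc": "2.0", "result": {"epoch": 800}}

cache.set("getEpochInfo", {}, response, ttl=3600)  # stores _cache_expiry alongside
hit = cache.get("getEpochInfo", {})                # expiry checked, metadata stripped
assert hit == response and "_cache_expiry" not in hit

cache.set("getGenesisHash", {}, response)          # no TTL: cached indefinitely
```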

cache_policy.py (new file)
View File

@ -0,0 +1,96 @@
from typing import Dict, Any, Optional
import time
class CachePolicy:
"""
Determines caching behavior for Solana RPC methods based on their characteristics.
"""
# Methods that return immutable data - cache indefinitely
CACHEABLE_IMMUTABLE = {
'getGenesisHash' # Network genesis hash never changes
}
# Methods with time-based TTL caching based on data change frequency
CACHEABLE_WITH_TTL = {
# Network/validator information - changes periodically
'getVoteAccounts': 120, # Validator vote accounts change every few minutes
'getSupply': 300, # Total SOL supply changes slowly
# Epoch and network info - changes with epoch boundaries (~2-3 days)
'getEpochInfo': 3600, # Current epoch info changes slowly
'getInflationRate': 1800, # Inflation rate changes infrequently
'getInflationGovernor': 3600, # Inflation governor params rarely change
# Network constants - change very rarely or never
'getEpochSchedule': 86400, # Epoch schedule rarely changes
'getVersion': 3600, # RPC version changes occasionally
'getIdentity': 3600, # Node identity changes rarely
# Immutable for a given set of parameters; different parameters simply create a new cache entry
'getBlock': 86400,
'getTransaction': 86400
}
def should_cache(self, method: str, params: Dict[str, Any]) -> bool:
"""
Determine if a method should be cached based on the method name and parameters.
Args:
method: The RPC method name
params: The method parameters
Returns:
True if the method should be cached, False otherwise
"""
if method in self.CACHEABLE_WITH_TTL:
# For getBlock, only cache finalized blocks
if method == 'getBlock':
commitment = self._get_commitment(params)
return commitment == 'finalized'
return True
if method in self.CACHEABLE_IMMUTABLE:
return True
# Default to not caching unknown methods
return False
def get_cache_ttl(self, method: str, params: Dict[str, Any]) -> Optional[int]:
"""
Get the Time To Live (TTL) for a cached method in seconds.
Args:
method: The RPC method name
params: The method parameters
Returns:
TTL in seconds, or None for indefinite caching
"""
if method in self.CACHEABLE_IMMUTABLE:
return None # Cache indefinitely
if method in self.CACHEABLE_WITH_TTL:
return self.CACHEABLE_WITH_TTL[method]
return None
def _get_commitment(self, params: Dict[str, Any]) -> str:
"""
Extract the commitment level from RPC parameters.
Args:
params: The method parameters
Returns:
The commitment level, defaults to 'processed'
"""
if isinstance(params, list) and len(params) > 1:
if isinstance(params[1], dict) and 'commitment' in params[1]:
return params[1]['commitment']
elif isinstance(params, dict) and 'commitment' in params:
return params['commitment']
return 'processed' # Default commitment level
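A few illustrative calls (hypothetical parameters) show how the policy classifies methods:

```python
# Sketch: how CachePolicy classifies a few calls (hypothetical params).
from cache_policy import CachePolicy

policy = CachePolicy()

assert policy.should_cache("getGenesisHash", []) is True                        # immutable
assert policy.should_cache("getBlock", [123, {"commitment": "finalized"}]) is True
assert policy.should_cache("getBlock", [123, {"commitment": "processed"}]) is False
assert policy.should_cache("getBalance", []) is False                           # unknown method

assert policy.get_cache_ttl("getGenesisHash", []) is None                       # indefinite
assert policy.get_cache_ttl("getEpochInfo", []) == 3600
```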

View File

@ -11,7 +11,7 @@ A Python-based reverse proxy for Solana RPC endpoints that provides unified acce
```
Provider class:
- name: str
- http_url: str
- http_url: str
- ws_url: str
- transform_request(request) -> request
- transform_response(response) -> response
@ -34,15 +34,15 @@ Provider class:
Cache class:
- get(method: str, params: dict) -> Optional[response]
- set(method: str, params: dict, response: dict) -> None
- size_check() -> None # Enforce 100GB limit
- size_check() -> None # Enforce 1GB limit
- clear_oldest() -> None # LRU eviction
```
**Implementation Notes**:
- Use `diskcache` library for simplicity
- Key format: `f"{method}:{json.dumps(params, sort_keys=True)}"`
- Key format: `f"{method}:{json.dumps(params, sort_keys=True)}"`
- Store both HTTP responses and WebSocket messages
- Implement 100GB limit with LRU eviction
- Implement 1GB limit with LRU eviction
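For illustration, `sort_keys=True` makes the key deterministic regardless of how the parameter dict was built (a sketch):

```python
# Sketch: deterministic cache key, independent of dict key order.
import json

params = {"signature": "abc", "commitment": "finalized"}
key = f"getTransaction:{json.dumps(params, sort_keys=True)}"
# 'getTransaction:{"commitment": "finalized", "signature": "abc"}'
```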
### 3. Error Logger Module (`errors.py`)
**Purpose**: SQLite-based error logging with UUID tracking
@ -90,7 +90,7 @@ Router class:
- providers: List[Provider]
- cache: Cache
- error_logger: ErrorLogger
-
-
- route_request(method: str, params: dict) -> response
- get_available_provider() -> Optional[Provider]
- mark_provider_failed(provider: Provider) -> None
@ -146,7 +146,7 @@ QUICKNODE_TOKEN=your_token_here
# Proxy settings
PROXY_PORT=8545
CACHE_SIZE_GB=100
CACHE_SIZE_GB=1
BACKOFF_MINUTES=30
# Logging
@ -227,7 +227,7 @@ Happy-path end-to-end tests only:
## Deployment Considerations
1. **Cache Storage**: Need ~100GB disk space
1. **Cache Storage**: Need ~1GB disk space
2. **Memory Usage**: Keep minimal, use disk cache
3. **Concurrent Clients**: Basic round-robin if multiple connect
4. **Monitoring**: Log all errors, provide error IDs
@ -273,8 +273,8 @@ aiohttp-cors==0.7.0
1. Single endpoint proxies to 5 providers
2. Automatic failover works
3. Responses are cached (up to 100GB)
3. Responses are cached (up to 1GB)
4. Errors logged with retrievable IDs
5. Both HTTP and WebSocket work
6. Response format is unified
7. Happy-path tests pass
7. Happy-path tests pass

View File

@ -7,10 +7,10 @@ from router import Router
async def handle_rpc_request(request: web.Request) -> web.Response:
router: Router = request.app['router']
logger = logging.getLogger(__name__)
try:
body = await request.json()
if not isinstance(body, dict):
return web.json_response({
"jsonrpc": "2.0",
@ -20,11 +20,11 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
"message": "Invalid Request"
}
}, status=400)
method = body.get("method")
params = body.get("params", [])
request_id = body.get("id", 1)
if not method:
return web.json_response({
"jsonrpc": "2.0",
@ -34,14 +34,15 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
"message": "Missing method"
}
}, status=400)
logger.info(f"Handling RPC request: {method}")
response = await router.route_request(method, params)
response["id"] = request_id
return web.json_response(response)
except json.JSONDecodeError:
return web.json_response({
"jsonrpc": "2.0",
@ -51,7 +52,7 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
"message": "Parse error"
}
}, status=400)
except Exception as e:
logger.error(f"Unexpected error: {e}")
return web.json_response({
@ -64,10 +65,5 @@ async def handle_rpc_request(request: web.Request) -> web.Response:
}, status=500)
def setup_routes(app: web.Application) -> None:
app.router.add_post('/', handle_rpc_request)
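The validation paths above can be exercised directly. A sketch, assuming a local proxy on the default port:

```python
# Sketch: exercising the handler's validation paths (assumes local proxy).
import json
import urllib.request
import urllib.error

def post(raw: bytes) -> dict:
    req = urllib.request.Request(
        "http://localhost:8545/",
        data=raw,
        headers={"Content-Type": "application/json"},
    )
    try:
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as e:
        return json.loads(e.read())  # 400 responses still carry a JSON-RPC error body

print(post(b"not json"))                              # "Parse error"
print(post(json.dumps({"jsonrpc": "2.0"}).encode()))  # "Missing method"
```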

main.py
View File

@ -1,5 +1,6 @@
import os
import logging
import asyncio
from dotenv import load_dotenv
from aiohttp import web
from providers import create_providers
@ -10,12 +11,31 @@ from http_proxy import setup_routes
from ws_proxy import setup_ws_routes
@web.middleware
async def cors_middleware(request, handler):
"""Add CORS headers to all responses"""
if request.method == 'OPTIONS':
# Handle preflight requests
return web.Response(headers={
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
'Access-Control-Allow-Headers': '*',
'Access-Control-Max-Age': '86400'
})
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = '*'
return response
def load_config() -> dict:
load_dotenv()
return {
"proxy_port": int(os.getenv("PROXY_PORT", 8545)),
"cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 100)),
"cache_size_gb": int(os.getenv("CACHE_SIZE_GB", 1)),
"backoff_minutes": int(os.getenv("BACKOFF_MINUTES", 30)),
"log_level": os.getenv("LOG_LEVEL", "INFO"),
"error_db_path": os.getenv("ERROR_DB_PATH", "./errors.db"),
@ -30,35 +50,33 @@ def setup_logging(log_level: str) -> None:
def create_app(config: dict) -> web.Application:
app = web.Application()
app = web.Application(middlewares=[cors_middleware])
providers = create_providers()
cache = Cache(size_limit_gb=config["cache_size_gb"])
error_logger = ErrorLogger(db_path=config["error_db_path"])
router = Router(providers, cache, error_logger)
app['router'] = router
app['config'] = config
setup_routes(app)
setup_ws_routes(app)
return app
def main() -> None:
config = load_config()
setup_logging(config["log_level"])
logger = logging.getLogger(__name__)
logger.info(f"Starting Solana RPC Proxy on port {config['proxy_port']}")
logger.info(f"Cache size limit: {config['cache_size_gb']}GB")
logger.info(f"Provider backoff time: {config['backoff_minutes']} minutes")
app = create_app(config)
web.run_app(
app,
host='0.0.0.0',
@ -67,4 +85,4 @@ def main() -> None:
if __name__ == "__main__":
main()
main()
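With the middleware installed, a preflight request can be verified directly (a sketch; assumes the proxy is running):

```python
# Sketch: check the CORS preflight headers set by the middleware.
import urllib.request

req = urllib.request.Request("http://localhost:8545/", method="OPTIONS")
with urllib.request.urlopen(req) as resp:
    print(resp.headers["Access-Control-Allow-Origin"])   # *
    print(resp.headers["Access-Control-Allow-Methods"])  # GET, POST, OPTIONS
    print(resp.headers["Access-Control-Max-Age"])        # 86400
```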

View File

@ -3,16 +3,16 @@ from typing import Dict, Any
def normalize_response(provider: str, response: Dict[str, Any]) -> Dict[str, Any]:
normalized = response.copy()
# Ensure consistent field names
if "result" in normalized and normalized["result"] is None:
# Some providers return null, others omit the field
pass
# Handle null vs missing fields consistently
if "error" in normalized and normalized["error"] is None:
del normalized["error"]
return normalized
@ -25,4 +25,4 @@ def normalize_error(error: Exception, error_id: str) -> Dict[str, Any]:
"message": str(error),
"data": {"error_id": error_id}
}
}
}
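The null-`error` handling above can be checked with a one-liner (a sketch, assuming the `normalize_response` function from this file):

```python
# Sketch: a null "error" field is dropped so providers look consistent.
resp = {"jsonrpc": "2.0", "result": {"slot": 1}, "error": None}
assert "error" not in normalize_response("helius", resp)
assert normalize_response("helius", resp)["result"] == {"slot": 1}
```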

View File

@ -8,28 +8,28 @@ class Provider(ABC):
def __init__(self, name: str):
self.name = name
self.backoff_until: Optional[datetime] = None
@property
@abstractmethod
def http_url(self) -> str:
pass
@property
@abstractmethod
def ws_url(self) -> str:
pass
def transform_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
return request
def transform_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
return response
def is_available(self) -> bool:
if self.backoff_until is None:
return True
return datetime.now() > self.backoff_until
def mark_failed(self, backoff_minutes: int = 30) -> None:
self.backoff_until = datetime.now() + timedelta(minutes=backoff_minutes)
@ -38,11 +38,11 @@ class AlchemyProvider(Provider):
def __init__(self):
super().__init__("alchemy")
self.api_key = os.getenv("ALCHEMY_API_KEY", "")
@property
def http_url(self) -> str:
return f"https://solana-mainnet.g.alchemy.com/v2/{self.api_key}"
@property
def ws_url(self) -> str:
return f"wss://solana-mainnet.g.alchemy.com/v2/{self.api_key}"
@ -51,11 +51,11 @@ class AlchemyProvider(Provider):
class PublicNodeProvider(Provider):
def __init__(self):
super().__init__("publicnode")
@property
def http_url(self) -> str:
return "https://solana-rpc.publicnode.com"
@property
def ws_url(self) -> str:
return "wss://solana-rpc.publicnode.com"
@ -65,11 +65,11 @@ class HeliusProvider(Provider):
def __init__(self):
super().__init__("helius")
self.api_key = os.getenv("HELIUS_API_KEY", "")
@property
def http_url(self) -> str:
return f"https://mainnet.helius-rpc.com/?api-key={self.api_key}"
@property
def ws_url(self) -> str:
return f"wss://mainnet.helius-rpc.com/?api-key={self.api_key}"
@ -80,11 +80,11 @@ class QuickNodeProvider(Provider):
super().__init__("quicknode")
self.endpoint = os.getenv("QUICKNODE_ENDPOINT", "")
self.token = os.getenv("QUICKNODE_TOKEN", "")
@property
def http_url(self) -> str:
return f"https://{self.endpoint}/{self.token}/"
@property
def ws_url(self) -> str:
return f"wss://{self.endpoint}/{self.token}/"
@ -93,11 +93,11 @@ class QuickNodeProvider(Provider):
class SolanaPublicProvider(Provider):
def __init__(self):
super().__init__("solana_public")
@property
def http_url(self) -> str:
return "https://api.mainnet-beta.solana.com"
@property
def ws_url(self) -> str:
return "wss://api.mainnet-beta.solana.com"
@ -105,9 +105,9 @@ class SolanaPublicProvider(Provider):
def create_providers() -> list[Provider]:
return [
SolanaPublicProvider(),
AlchemyProvider(),
PublicNodeProvider(),
HeliusProvider(),
QuickNodeProvider(),
SolanaPublicProvider()
]
]
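New upstreams plug in by subclassing `Provider`. A hypothetical example (the name, env var, and URLs are illustrative, not a real integration):

```python
# Hypothetical provider subclass; name, env var, and URLs are illustrative.
import os

class ExampleProvider(Provider):
    def __init__(self):
        super().__init__("example")
        self.api_key = os.getenv("EXAMPLE_API_KEY", "")

    @property
    def http_url(self) -> str:
        return f"https://solana.example-rpc.com/{self.api_key}"

    @property
    def ws_url(self) -> str:
        return f"wss://solana.example-rpc.com/{self.api_key}"
```

Appending `ExampleProvider()` to the list returned by `create_providers()` puts it into the round-robin rotation.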

View File

@ -4,6 +4,7 @@ import logging
from typing import Dict, Any, Optional, List
from providers import Provider
from cache import Cache
from cache_policy import CachePolicy
from errors import ErrorLogger
@ -11,20 +12,25 @@ class Router:
def __init__(self, providers: List[Provider], cache: Cache, error_logger: ErrorLogger):
self.providers = providers
self.cache = cache
self.cache_policy = CachePolicy()
self.error_logger = error_logger
self.current_provider_index = 0
self.logger = logging.getLogger(__name__)
async def route_request(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
request = {"method": method, "params": params}
cached_response = self.cache.get(method, params)
if cached_response:
self.logger.debug(f"Cache hit for {method}")
cached_response["_cached"] = True
cached_response["_provider"] = "cache"
return cached_response
# Check if this method should be cached based on caching policy
should_cache = self.cache_policy.should_cache(method, params)
if should_cache:
cached_response = self.cache.get(method, params)
if cached_response:
self.logger.debug(f"Cache hit for {method}")
cached_response["_cached"] = True
cached_response["_provider"] = "cache"
return cached_response
for attempt in range(len(self.providers)):
provider = self.get_next_available_provider()
if not provider:
@ -32,48 +38,84 @@ class Router:
"All providers are currently unavailable",
"NO_AVAILABLE_PROVIDERS"
)
try:
response = await self._make_request(provider, request)
transformed_response = provider.transform_response(response)
transformed_response["_cached"] = False
transformed_response["_provider"] = provider.name
self.cache.set(method, params, transformed_response)
self.logger.info(f"Request succeeded via {provider.name}")
# Cache the response if caching policy allows it
if should_cache:
ttl = self.cache_policy.get_cache_ttl(method, params)
self.cache.set(method, params, transformed_response, ttl)
cache_info = f" (cached {'indefinitely' if ttl is None else f'for {ttl}s'})"
self.logger.info(f"Request succeeded via {provider.name}{cache_info}")
else:
self.logger.info(f"Request succeeded via {provider.name} (not cached)")
return transformed_response
except Exception as error:
error_id = self.error_logger.log_error(provider.name, request, error)
self.logger.warning(f"Provider {provider.name} failed: {error} (ID: {error_id})")
provider.mark_failed()
# Only mark provider as failed for server/network issues, not RPC errors
if await self._is_server_failure(provider, error):
provider.mark_failed()
self.logger.warning(f"Provider {provider.name} marked as failed due to server issue")
else:
self.logger.debug(f"Provider {provider.name} had RPC error but server is available")
return self._create_error_response(
"All providers failed to handle the request",
"ALL_PROVIDERS_FAILED"
)
def get_next_available_provider(self) -> Optional[Provider]:
for _ in range(len(self.providers)):
provider = self.providers[self.current_provider_index]
self.current_provider_index = (self.current_provider_index + 1) % len(self.providers)
if provider.is_available():
return provider
return None
async def _is_server_failure(self, provider: Provider, error: Exception) -> bool:
"""
Check if the provider server is actually down by making a simple health check.
Only mark as failed if server is unreachable.
"""
try:
# Quick health check with minimal timeout
timeout = aiohttp.ClientTimeout(total=5) # 5 second timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
# Try a simple HTTP GET to check server availability
from urllib.parse import urlparse
parsed_url = urlparse(provider.http_url)
health_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
async with session.get(health_url) as response:
# Server responded (even with error codes), so it's alive
return False
except Exception as health_error:
# Server is actually unreachable
self.logger.debug(f"Health check failed for {provider.name}: {health_error}")
return True
async def _make_request(self, provider: Provider, request: Dict[str, Any]) -> Dict[str, Any]:
transformed_request = provider.transform_request(request)
rpc_request = {
"jsonrpc": "2.0",
"id": 1,
"method": transformed_request["method"],
"params": transformed_request["params"]
}
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
@ -83,14 +125,14 @@ class Router:
) as response:
if response.status != 200:
raise Exception(f"HTTP {response.status}: {await response.text()}")
result = await response.json()
if "error" in result:
raise Exception(f"RPC Error: {result['error']}")
return result
def _create_error_response(self, message: str, code: str) -> Dict[str, Any]:
return {
"jsonrpc": "2.0",

View File

@ -11,110 +11,125 @@ class WebSocketProxy:
self.router = router
self.logger = logging.getLogger(__name__)
self.subscription_mappings: Dict[str, str] = {}
async def handle_ws_connection(self, request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
self.logger.info("New WebSocket connection established")
provider = self.router.get_next_available_provider()
if not provider:
await ws.close(code=1011, message=b'No available providers')
return ws
try:
provider_ws = await self._connect_to_provider(provider)
if not provider_ws:
provider_connection = await self._connect_to_provider(provider)
if not provider_connection:
await ws.close(code=1011, message=b'Failed to connect to provider')
return ws
provider_ws, provider_session = provider_connection
await asyncio.gather(
self._proxy_client_to_provider(ws, provider_ws, provider),
self._proxy_provider_to_client(provider_ws, ws, provider),
return_exceptions=True
)
# Clean up provider connection
if not provider_ws.closed:
await provider_ws.close()
await provider_session.close()
except Exception as e:
self.logger.error(f"WebSocket proxy error: {e}")
finally:
if not ws.closed:
await ws.close()
return ws
async def _connect_to_provider(self, provider) -> Optional[object]:
async def _connect_to_provider(self, provider) -> Optional[tuple]:
session = None
try:
session = ClientSession()
self.logger.info(f"Attempting to connect to provider {provider.name} at {provider.ws_url}")
ws = await session.ws_connect(provider.ws_url)
self.logger.info(f"Connected to provider {provider.name} WebSocket")
return ws
self.logger.info(f"Successfully connected to provider {provider.name} WebSocket at {provider.ws_url}")
return (ws, session)
except Exception as e:
self.logger.error(f"Failed to connect to provider {provider.name}: {e}")
self.logger.error(f"Failed to connect to provider {provider.name} at {provider.ws_url}: {e}")
if session:
await session.close()
return None
async def _proxy_client_to_provider(self, client_ws: web.WebSocketResponse, provider_ws, provider) -> None:
async for msg in client_ws:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
method = data.get('method', 'unknown')
self.logger.debug(f"Received from client: {data}")
transformed_request = provider.transform_request(data)
await provider_ws.send_str(json.dumps(transformed_request))
self.logger.debug(f"Forwarded message to {provider.name}: {data.get('method', 'unknown')}")
self.logger.debug(f"Forwarded message to {provider.name}: {method}")
except json.JSONDecodeError:
self.logger.warning("Received invalid JSON from client")
except Exception as e:
self.logger.error(f"Error forwarding to provider: {e}")
break
elif msg.type == WSMsgType.ERROR:
self.logger.error(f'WebSocket error: {client_ws.exception()}')
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
break
async def _proxy_provider_to_client(self, provider_ws, client_ws: web.WebSocketResponse, provider) -> None:
async for msg in provider_ws:
if msg.type == WSMsgType.TEXT:
try:
data = json.loads(msg.data)
self.logger.debug(f"Received from provider {provider.name}: {data}")
transformed_response = provider.transform_response(data)
if "result" in transformed_response and "subscription" in str(transformed_response.get("result", {})):
subscription_id = transformed_response.get("result")
if subscription_id:
self.subscription_mappings[str(subscription_id)] = provider.name
self.logger.debug(f"SIGNATURE_SUBSCRIBE: Mapped subscription {subscription_id} to {provider.name}")
transformed_response["_cached"] = False
transformed_response["_provider"] = provider.name
method = transformed_response.get("method", "")
params = transformed_response.get("params", {})
if method and params:
self.router.cache.set(method, params, transformed_response)
await client_ws.send_str(json.dumps(transformed_response))
self.logger.debug(f"Forwarded response from {provider.name}")
self.logger.debug(f"Forwarded response to client from {provider.name}: {transformed_response}")
except json.JSONDecodeError:
self.logger.warning(f"Received invalid JSON from provider {provider.name}")
except Exception as e:
self.logger.error(f"Error forwarding from provider: {e}")
break
# Don't break here - continue processing other messages
continue
elif msg.type == WSMsgType.ERROR:
self.logger.error(f'Provider WebSocket error: {provider_ws.exception()}')
break
elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING):
self.logger.warning(f"Provider WebSocket connection closed from {provider.name}")
break
self.logger.warning(f"Provider-to-client message loop ended for {provider.name}")
async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
router: Router = request.app['router']
@ -123,4 +138,4 @@ async def handle_ws_connection(request: web.Request) -> web.WebSocketResponse:
def setup_ws_routes(app: web.Application) -> None:
app.router.add_get('/ws', handle_ws_connection)
app.router.add_get('/ws', handle_ws_connection)
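A minimal client sketch for the `/ws` endpoint (assumes the proxy is running and `aiohttp` is installed):

```python
# Sketch: subscribe to slot updates through the proxy's /ws endpoint.
import asyncio
import json
import aiohttp

async def main() -> None:
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost:8545/ws") as ws:
            await ws.send_str(json.dumps(
                {"jsonrpc": "2.0", "id": 1, "method": "slotSubscribe"}))
            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    print(msg.data)  # subscription id, then slot notifications

asyncio.run(main())
```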