first version
config CORS fixed key to one line helper prints clean up logs improved validations again validations fix rewritten flask and node.js solution added subscription route auth flow diagrams
This commit is contained in:
289
app.py
Normal file
289
app.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
import signal
|
||||
import pathlib
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from dotenv import load_dotenv
|
||||
from pywebpush import webpush, WebPushException
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, os.getenv('LOG_LEVEL', 'INFO')),
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# CORS Configuration
|
||||
ALLOWED_ORIGINS = [
|
||||
"https://game-timer.virtonline.eu",
|
||||
# Add other allowed origins if needed
|
||||
]
|
||||
ALLOWED_METHODS = ["POST", "OPTIONS"]
|
||||
ALLOWED_HEADERS = ["Content-Type"]
|
||||
|
||||
class FlicButtonHandler:
|
||||
def __init__(self):
|
||||
# Load button configurations
|
||||
self.button_configs = {
|
||||
os.getenv('FLIC_BUTTON1_SERIAL'): self.handle_button1,
|
||||
os.getenv('FLIC_BUTTON2_SERIAL'): self.handle_button2,
|
||||
os.getenv('FLIC_BUTTON3_SERIAL'): self.handle_button3
|
||||
}
|
||||
|
||||
# Ensure subscriptions file and directory exist
|
||||
self.subscriptions_file = os.getenv('SUBSCRIPTIONS_FILE', 'app/subscriptions.json')
|
||||
self._ensure_subscriptions_file()
|
||||
|
||||
# Load subscriptions
|
||||
self.subscriptions = self.load_subscriptions()
|
||||
|
||||
# Prepare VAPID keys
|
||||
self.vapid_private_key = self._decode_vapid_private_key()
|
||||
|
||||
def _ensure_subscriptions_file(self):
|
||||
"""
|
||||
Ensure the subscriptions file and its parent directory exist.
|
||||
Create them if they don't.
|
||||
"""
|
||||
try:
|
||||
# Create parent directory if it doesn't exist
|
||||
pathlib.Path(self.subscriptions_file).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create file if it doesn't exist
|
||||
if not os.path.exists(self.subscriptions_file):
|
||||
with open(self.subscriptions_file, 'w') as f:
|
||||
json.dump([], f)
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring subscriptions file: {e}")
|
||||
raise
|
||||
|
||||
def _decode_vapid_private_key(self):
|
||||
"""Load and strictly validate VAPID private key."""
|
||||
try:
|
||||
# Get and clean the key
|
||||
env_key = os.getenv('VAPID_PRIVATE_KEY', '').strip()
|
||||
|
||||
# Reconstruct PEM format if missing headers
|
||||
if not env_key.startswith('-----BEGIN PRIVATE KEY-----'):
|
||||
env_key = f"-----BEGIN PRIVATE KEY-----\n{env_key}\n-----END PRIVATE KEY-----"
|
||||
|
||||
# Strict validation and key preparation
|
||||
key = serialization.load_pem_private_key(
|
||||
env_key.encode('utf-8'),
|
||||
password=None,
|
||||
backend=default_backend()
|
||||
)
|
||||
|
||||
# Return in strict PEM format
|
||||
return key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption()
|
||||
).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"CRITICAL: Invalid VAPID key - {str(e)}")
|
||||
raise RuntimeError("Invalid VAPID private key configuration") from e
|
||||
|
||||
def load_subscriptions(self) -> List[Dict]:
|
||||
"""Load web push subscriptions from file."""
|
||||
try:
|
||||
with open(self.subscriptions_file, 'r') as f:
|
||||
# Handle empty file case
|
||||
content = f.read().strip()
|
||||
return json.loads(content) if content else []
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Error decoding subscriptions from {self.subscriptions_file}")
|
||||
return []
|
||||
|
||||
def save_subscriptions(self):
|
||||
"""Save web push subscriptions to file."""
|
||||
try:
|
||||
with open(self.subscriptions_file, 'w') as f:
|
||||
json.dump(self.subscriptions, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving subscriptions: {e}")
|
||||
|
||||
async def send_push_notification(self, subscription: Dict, message: str):
|
||||
try:
|
||||
# Determine audience (aud) claim for VAPID
|
||||
endpoint = subscription['endpoint']
|
||||
aud = (endpoint.split('/send')[0] if '/send' in endpoint
|
||||
else endpoint.split('/fcm/send')[0])
|
||||
|
||||
logger.debug(f"Sending to: {endpoint[:50]}...")
|
||||
logger.debug(f"Using aud: {aud}")
|
||||
|
||||
# Perform web push
|
||||
result = webpush(
|
||||
subscription_info=subscription,
|
||||
data=message,
|
||||
vapid_private_key=self.vapid_private_key,
|
||||
vapid_claims={
|
||||
"sub": os.getenv('VAPID_CLAIM_EMAIL'),
|
||||
"aud": aud + "/" # Ensure trailing slash
|
||||
},
|
||||
ttl=86400 # 24 hour expiration
|
||||
)
|
||||
logger.info("Push notification sent successfully")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Push failed: {str(e)}")
|
||||
logger.error(f"Endpoint details: {subscription['endpoint']}")
|
||||
logger.error(f"Keys: {subscription.get('keys', 'No keys found')}")
|
||||
return False
|
||||
|
||||
async def handle_button1(self):
|
||||
"""Handle first button action - e.g., Home Lights On"""
|
||||
logger.info("Button 1 pressed: Home Lights On")
|
||||
message = json.dumps({"action": "home_lights_on"})
|
||||
await self.broadcast_notification(message)
|
||||
|
||||
async def handle_button2(self):
|
||||
"""Handle second button action - e.g., Security System Arm"""
|
||||
logger.info("Button 2 pressed: Security System Arm")
|
||||
message = json.dumps({"action": "security_arm"})
|
||||
await self.broadcast_notification(message)
|
||||
|
||||
async def handle_button3(self):
|
||||
"""Handle third button action - e.g., Panic Button"""
|
||||
logger.info("Button 3 pressed: Panic Alert")
|
||||
message = json.dumps({"action": "panic_alert"})
|
||||
await self.broadcast_notification(message)
|
||||
|
||||
async def broadcast_notification(self, message: str):
|
||||
"""Broadcast notification to all subscriptions with error handling."""
|
||||
if not self.subscriptions:
|
||||
logger.warning("No subscriptions to broadcast to")
|
||||
return
|
||||
|
||||
success_count = 0
|
||||
for subscription in self.subscriptions:
|
||||
try:
|
||||
success = await self.send_push_notification(subscription, message)
|
||||
if success:
|
||||
success_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to {subscription['endpoint'][:30]}...: {str(e)}")
|
||||
|
||||
logger.info(f"Notifications sent: {success_count}/{len(self.subscriptions)}")
|
||||
|
||||
async def handle_flic_webhook(self, request):
|
||||
"""Webhook endpoint for Flic button events."""
|
||||
try:
|
||||
data = await request.json()
|
||||
button_serial = data.get('serial')
|
||||
|
||||
# Validate button serial
|
||||
if button_serial not in self.button_configs:
|
||||
logger.warning(f"Unknown button serial: {button_serial}")
|
||||
return web.Response(status=400)
|
||||
|
||||
# Call the corresponding button handler
|
||||
handler = self.button_configs[button_serial]
|
||||
await handler()
|
||||
|
||||
return web.Response(status=200)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Flic webhook: {e}")
|
||||
return web.Response(status=500)
|
||||
|
||||
async def handle_subscribe(self, request):
|
||||
"""Add a new web push subscription."""
|
||||
try:
|
||||
subscription = await request.json()
|
||||
|
||||
# Check if subscription already exists
|
||||
if subscription not in self.subscriptions:
|
||||
self.subscriptions.append(subscription)
|
||||
self.save_subscriptions()
|
||||
logger.info("New subscription added")
|
||||
|
||||
return web.Response(status=200)
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription error: {e}")
|
||||
return web.Response(status=500)
|
||||
|
||||
def create_app():
|
||||
"""Create and configure the aiohttp application."""
|
||||
app = web.Application()
|
||||
handler = FlicButtonHandler()
|
||||
|
||||
async def options_handler(request):
|
||||
"""Handle OPTIONS requests for CORS preflight."""
|
||||
origin = request.headers.get('Origin', '')
|
||||
if origin in ALLOWED_ORIGINS:
|
||||
headers = {
|
||||
'Access-Control-Allow-Origin': origin,
|
||||
'Access-Control-Allow-Methods': ', '.join(ALLOWED_METHODS),
|
||||
'Access-Control-Allow-Headers': ', '.join(ALLOWED_HEADERS),
|
||||
'Access-Control-Max-Age': '86400', # 24 hours
|
||||
}
|
||||
return web.Response(status=200, headers=headers)
|
||||
return web.Response(status=403) # Forbidden origin
|
||||
|
||||
async def add_cors_headers(request, response):
|
||||
"""Add CORS headers to normal responses."""
|
||||
origin = request.headers.get('Origin', '')
|
||||
if origin in ALLOWED_ORIGINS:
|
||||
response.headers['Access-Control-Allow-Origin'] = origin
|
||||
response.headers['Access-Control-Expose-Headers'] = 'Content-Type'
|
||||
return response
|
||||
|
||||
# Register middleware
|
||||
app.on_response_prepare.append(add_cors_headers)
|
||||
|
||||
# Setup routes with OPTIONS handlers
|
||||
app.router.add_route('OPTIONS', '/flic-webhook', options_handler)
|
||||
app.router.add_route('OPTIONS', '/subscribe', options_handler)
|
||||
|
||||
# Original routes
|
||||
app.router.add_post('/flic-webhook', handler.handle_flic_webhook)
|
||||
app.router.add_post('/subscribe', handler.handle_subscribe)
|
||||
|
||||
return app
|
||||
|
||||
async def main():
|
||||
"""Main application entry point."""
|
||||
app = create_app()
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, '0.0.0.0', 8080)
|
||||
await site.start()
|
||||
|
||||
logger.info("Application started on port 8080")
|
||||
|
||||
# Create an event to keep the application running
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
def signal_handler():
|
||||
"""Handle interrupt signals to gracefully stop the application."""
|
||||
logger.info("Received shutdown signal")
|
||||
stop_event.set()
|
||||
|
||||
# Register signal handlers
|
||||
loop = asyncio.get_running_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Wait until stop event is set
|
||||
await stop_event.wait()
|
||||
|
||||
# Cleanup
|
||||
await runner.cleanup()
|
||||
logger.info("Application shutting down")
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user