"""Web service that turns Flic button webhooks into Web Push notifications."""
import asyncio
import base64
import functools
import json
import logging
import os
import pathlib
import signal
from typing import Dict, List
from urllib.parse import urlsplit

import aiohttp
from aiohttp import web
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from dotenv import load_dotenv
from pywebpush import webpush, WebPushException

# Load environment variables
load_dotenv()


def _resolve_log_level() -> int:
    """Translate the LOG_LEVEL environment variable into a logging level.

    The name is matched case-insensitively; unknown names fall back to INFO
    instead of raising AttributeError at import time.
    """
    level = getattr(logging, os.getenv('LOG_LEVEL', 'INFO').strip().upper(), None)
    return level if isinstance(level, int) else logging.INFO


# Configure logging
logging.basicConfig(
    level=_resolve_log_level(),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# CORS Configuration
ALLOWED_ORIGINS = [
    "https://game-timer.virtonline.eu",
    # Add other allowed origins if needed
]
ALLOWED_METHODS = ["POST", "OPTIONS"]
ALLOWED_HEADERS = ["Content-Type"]


class FlicButtonHandler:
    """Dispatch Flic button presses to Web Push subscribers.

    Button serials come from the FLIC_BUTTON{1,2,3}_SERIAL environment
    variables, subscriptions are persisted as a JSON list in the file named
    by SUBSCRIPTIONS_FILE, and pushes are signed with VAPID_PRIVATE_KEY.
    """

    def __init__(self):
        # Map each configured serial to its handler. Unset variables are
        # skipped: os.getenv() returns None for them, and a None key would
        # match any webhook payload that has no 'serial' field.
        env_handlers = (
            ('FLIC_BUTTON1_SERIAL', self.handle_button1),
            ('FLIC_BUTTON2_SERIAL', self.handle_button2),
            ('FLIC_BUTTON3_SERIAL', self.handle_button3),
        )
        self.button_configs = {}
        for env_name, callback in env_handlers:
            serial = os.getenv(env_name)
            if serial:
                self.button_configs[serial] = callback
            else:
                logger.warning(f"{env_name} is not set; that button is disabled")

        # Ensure subscriptions file and directory exist
        self.subscriptions_file = os.getenv('SUBSCRIPTIONS_FILE', 'app/subscriptions.json')
        self._ensure_subscriptions_file()

        # Load subscriptions
        self.subscriptions = self.load_subscriptions()

        # Prepare VAPID keys
        self.vapid_private_key = self._decode_vapid_private_key()

    def _ensure_subscriptions_file(self):
        """Create the subscriptions file (and parent directory) if missing.

        A new file is initialised with an empty JSON list. Any failure is
        logged and re-raised.
        """
        try:
            # Create parent directory if it doesn't exist
            pathlib.Path(self.subscriptions_file).parent.mkdir(parents=True, exist_ok=True)

            # Create file if it doesn't exist
            if not os.path.exists(self.subscriptions_file):
                with open(self.subscriptions_file, 'w') as f:
                    json.dump([], f)
        except Exception as e:
            logger.error(f"Error ensuring subscriptions file: {e}")
            raise

    def _decode_vapid_private_key(self):
        """Load and strictly validate the VAPID private key.

        Reads VAPID_PRIVATE_KEY, strips surrounding whitespace and quotes,
        expands literal ``\\n`` sequences and adds PEM headers when absent.

        Returns:
            The key re-serialised as an unencrypted PKCS8 PEM string.

        Raises:
            RuntimeError: if the key cannot be parsed.
        """
        try:
            # Get and clean the key
            env_key = os.getenv('VAPID_PRIVATE_KEY', '').strip().strip('"\'')

            # Convert escaped newlines into real ones
            private_pem = env_key.replace('\\n', '\n')

            # Ensure proper PEM headers
            if not private_pem.startswith('-----BEGIN PRIVATE KEY-----'):
                private_pem = f"-----BEGIN PRIVATE KEY-----\n{private_pem}\n-----END PRIVATE KEY-----"

            # Strict validation
            key = serialization.load_pem_private_key(
                private_pem.encode('utf-8'),
                password=None,
                backend=default_backend()
            )

            # Return in strict PEM format
            return key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ).decode('utf-8')
        except Exception as e:
            logger.error(f"CRITICAL: Invalid VAPID key - {str(e)}")
            raise RuntimeError("Invalid VAPID private key configuration") from e

    def load_subscriptions(self) -> List[Dict]:
        """Load web push subscriptions from file.

        Returns:
            The stored list, or an empty list when the file is empty, is not
            valid JSON, or does not contain a JSON list.
        """
        try:
            with open(self.subscriptions_file, 'r') as f:
                # Handle empty file case
                content = f.read().strip()
            loaded = json.loads(content) if content else []
        except json.JSONDecodeError:
            logger.error(f"Error decoding subscriptions from {self.subscriptions_file}")
            return []

        # Callers append to and iterate over the result, so it must be a list.
        if not isinstance(loaded, list):
            logger.error(f"Error decoding subscriptions from {self.subscriptions_file}")
            return []
        return loaded

    def save_subscriptions(self):
        """Save web push subscriptions to file; failures are logged, not raised."""
        try:
            with open(self.subscriptions_file, 'w') as f:
                json.dump(self.subscriptions, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving subscriptions: {e}")

    @staticmethod
    def _push_audience(endpoint: str) -> str:
        """Return the origin (scheme://host[:port]) of a push endpoint URL.

        The VAPID "aud" claim must be the origin of the push resource, so the
        path is dropped and no trailing slash is added.
        """
        parts = urlsplit(endpoint)
        return f"{parts.scheme}://{parts.netloc}"

    async def send_push_notification(self, subscription: Dict, message: str):
        """Send one push message to one subscription.

        Args:
            subscription: Push subscription dict; must contain 'endpoint'.
            message: Payload string to deliver.

        Returns:
            True when the push was sent, False when it failed. Failures are
            logged and never raised to the caller.
        """
        try:
            endpoint = subscription['endpoint']
            aud = self._push_audience(endpoint)

            logger.debug(f"Sending to: {endpoint[:50]}...")
            logger.debug(f"Using aud: {aud}")

            send = functools.partial(
                webpush,
                subscription_info=subscription,
                data=message,
                vapid_private_key=self.vapid_private_key,
                vapid_claims={
                    "sub": os.getenv('VAPID_CLAIM_EMAIL'),
                    "aud": aud
                },
                ttl=86400  # 24 hour expiration
            )
            # webpush() is a plain synchronous call; run it in the default
            # executor so the event loop keeps serving other requests.
            await asyncio.get_running_loop().run_in_executor(None, send)

            logger.info("Push notification sent successfully")
            return True
        except Exception as e:
            logger.error(f"Push failed: {str(e)}")
            return False

    async def _notify_button_press(self, description: str, action: str):
        """Log a button press and broadcast its action payload."""
        logger.info(description)
        await self.broadcast_notification(json.dumps({"action": action}))

    async def handle_button1(self):
        """Handle first button action - e.g., Home Lights On"""
        await self._notify_button_press("Button 1 pressed: Home Lights On", "home_lights_on")

    async def handle_button2(self):
        """Handle second button action - e.g., Security System Arm"""
        await self._notify_button_press("Button 2 pressed: Security System Arm", "security_arm")

    async def handle_button3(self):
        """Handle third button action - e.g., Panic Button"""
        await self._notify_button_press("Button 3 pressed: Panic Alert", "panic_alert")

    async def broadcast_notification(self, message: str):
        """Send ``message`` to every stored subscription and log the tally."""
        if not self.subscriptions:
            logger.warning("No subscriptions to broadcast to")
            return

        # Work on a snapshot: a subscribe request may append to the live list
        # while this coroutine is suspended on a send.
        targets = list(self.subscriptions)
        success_count = 0
        for subscription in targets:
            # send_push_notification reports failure through its return
            # value, not by raising.
            if await self.send_push_notification(subscription, message):
                success_count += 1

        logger.info(f"Notifications sent: {success_count}/{len(targets)}")

    async def handle_flic_webhook(self, request):
        """Webhook endpoint for Flic button events.

        Expects a JSON object with a 'serial' field. Responds 400 for a
        malformed body or unknown serial, 500 when the button handler fails,
        and 200 otherwise.
        """
        try:
            data = await request.json()
        except ValueError as e:  # includes json.JSONDecodeError
            logger.warning(f"Invalid Flic webhook payload: {e}")
            return web.Response(status=400)

        button_serial = data.get('serial') if isinstance(data, dict) else None

        # Validate button serial (non-string values can never match a key)
        handler = self.button_configs.get(button_serial) if isinstance(button_serial, str) else None
        if handler is None:
            logger.warning(f"Unknown button serial: {button_serial}")
            return web.Response(status=400)

        try:
            # Call the corresponding button handler
            await handler()
        except Exception as e:
            logger.error(f"Error processing Flic webhook: {e}")
            return web.Response(status=500)
        return web.Response(status=200)

    async def handle_subscribe(self, request):
        """Add a new web push subscription.

        Expects a JSON object with a non-empty string 'endpoint'. Responds
        400 for anything else, 500 on an unexpected error, 200 on success
        (including when the subscription is already stored).
        """
        try:
            subscription = await request.json()
        except ValueError as e:  # includes json.JSONDecodeError
            logger.warning(f"Invalid subscription payload: {e}")
            return web.Response(status=400)

        endpoint = subscription.get('endpoint') if isinstance(subscription, dict) else None
        if not isinstance(endpoint, str) or not endpoint:
            logger.warning("Rejected subscription without a valid endpoint")
            return web.Response(status=400)

        try:
            # Check if subscription already exists
            if subscription not in self.subscriptions:
                self.subscriptions.append(subscription)
                self.save_subscriptions()
                logger.info("New subscription added")
            return web.Response(status=200)
        except Exception as e:
            logger.error(f"Subscription error: {e}")
            return web.Response(status=500)


def create_app():
    """Create and configure the aiohttp application."""
    app = web.Application()
    handler = FlicButtonHandler()

    async def options_handler(request):
        """Handle OPTIONS requests for CORS preflight."""
        origin = request.headers.get('Origin', '')
        if origin in ALLOWED_ORIGINS:
            headers = {
                'Access-Control-Allow-Origin': origin,
                'Access-Control-Allow-Methods': ', '.join(ALLOWED_METHODS),
                'Access-Control-Allow-Headers': ', '.join(ALLOWED_HEADERS),
                'Access-Control-Max-Age': '86400',  # 24 hours
            }
            return web.Response(status=200, headers=headers)
        return web.Response(status=403)  # Forbidden origin

    async def add_cors_headers(request, response):
        """Add CORS headers to normal responses."""
        origin = request.headers.get('Origin', '')
        if origin in ALLOWED_ORIGINS:
            response.headers['Access-Control-Allow-Origin'] = origin
            response.headers['Access-Control-Expose-Headers'] = 'Content-Type'
        return response

    # Register the response-prepare signal handler
    app.on_response_prepare.append(add_cors_headers)

    # Setup routes with OPTIONS handlers
    app.router.add_route('OPTIONS', '/flic-webhook', options_handler)
    app.router.add_route('OPTIONS', '/subscribe', options_handler)

    # Original routes
    app.router.add_post('/flic-webhook', handler.handle_flic_webhook)
    app.router.add_post('/subscribe', handler.handle_subscribe)

    return app


async def main():
    """Main application entry point: serve on port 8080 until SIGINT/SIGTERM."""
    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()

    try:
        site = web.TCPSite(runner, '0.0.0.0', 8080)
        await site.start()
        logger.info("Application started on port 8080")

        # Create an event to keep the application running
        stop_event = asyncio.Event()

        def signal_handler():
            """Handle interrupt signals to gracefully stop the application."""
            logger.info("Received shutdown signal")
            stop_event.set()

        # Register signal handlers
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, signal_handler)

        # Wait until stop event is set
        await stop_event.wait()
    finally:
        # Release the runner even when startup or waiting fails.
        await runner.cleanup()
        logger.info("Application shutting down")


if __name__ == '__main__':
    asyncio.run(main())