Files
flic-webhook-webpush/app.py
2025-03-26 08:16:56 +01:00

289 lines
10 KiB
Python

import asyncio
import json
import logging
import os
import base64
from typing import Dict, List
import signal
import pathlib
import aiohttp
from aiohttp import web
from dotenv import load_dotenv
from pywebpush import webpush, WebPushException
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
# Load environment variables from a .env file (python-dotenv) so that the
# os.getenv() lookups throughout this module can see them.
load_dotenv()
# Configure logging
logging.basicConfig(
level=getattr(logging, os.getenv('LOG_LEVEL', 'INFO')),
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# CORS Configuration
ALLOWED_ORIGINS = [
"https://game-timer.virtonline.eu",
# Add other allowed origins if needed
]
ALLOWED_METHODS = ["POST", "OPTIONS"]
ALLOWED_HEADERS = ["Content-Type"]
class FlicButtonHandler:
    """Bridge Flic button webhooks to Web Push notifications.

    Each button serial configured through the ``FLIC_BUTTON1_SERIAL`` ..
    ``FLIC_BUTTON3_SERIAL`` environment variables is mapped to a handler that
    broadcasts a JSON ``{"action": ...}`` message to every stored push
    subscription. Subscriptions are kept in memory and persisted as a JSON
    list in ``SUBSCRIPTIONS_FILE`` (default ``app/subscriptions.json``).
    """

    # Environment variable holding each button's serial, paired with the name
    # of the coroutine method that handles that button's presses.
    _BUTTON_ENV_HANDLERS = (
        ('FLIC_BUTTON1_SERIAL', 'handle_button1'),
        ('FLIC_BUTTON2_SERIAL', 'handle_button2'),
        ('FLIC_BUTTON3_SERIAL', 'handle_button3'),
    )

    def __init__(self):
        # Load button configurations (serial -> handler coroutine method)
        self.button_configs = self._build_button_configs()
        # Ensure subscriptions file and directory exist
        self.subscriptions_file = os.getenv('SUBSCRIPTIONS_FILE', 'app/subscriptions.json')
        self._ensure_subscriptions_file()
        # Load subscriptions
        self.subscriptions = self.load_subscriptions()
        # Contact URI sent as the VAPID "sub" claim. Defaults to the former
        # hard-coded placeholder; set VAPID_SUBJECT to a real mailto:/https: URI.
        self.vapid_subject = os.getenv('VAPID_SUBJECT', 'mailto:your-email@example.com')
        # Prepare VAPID keys
        self.vapid_private_key = self._decode_vapid_private_key()

    def _build_button_configs(self) -> Dict:
        """Map each configured button serial to its bound handler method.

        Buttons whose serial variable is unset or empty are skipped with a
        warning. Previously every missing serial became the dict key ``None``,
        so a webhook without a ``serial`` field ran whichever handler was
        registered last under ``None`` (``handle_button3``, the panic alert).
        """
        configs = {}
        for env_name, method_name in self._BUTTON_ENV_HANDLERS:
            serial = os.getenv(env_name)
            if not serial:
                logger.warning(f"{env_name} is not set; that button is disabled")
                continue
            configs[serial] = getattr(self, method_name)
        return configs

    def _ensure_subscriptions_file(self):
        """Create the subscriptions file (containing ``[]``) and its parent
        directory if they do not exist yet.

        Raises:
            Any filesystem error, after logging it.
        """
        try:
            path = pathlib.Path(self.subscriptions_file)
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text(json.dumps([]), encoding='utf-8')
        except Exception as e:
            logger.error(f"Error ensuring subscriptions file: {e}")
            raise

    def _decode_vapid_private_key(self):
        """Read ``VAPID_PRIVATE_KEY`` and return it as validated PEM text.

        The value is stripped of surrounding whitespace and quotes, then:

        * if it contains literal ``\\n`` sequences, they become real newlines;
        * elif it contains a ``BEGIN PRIVATE KEY`` header, it is used as-is;
        * else it is base64url-decoded, and the decoded bytes must be UTF-8.

        If the result does not start with the ``BEGIN PRIVATE KEY`` line it is
        wrapped in BEGIN/END markers. The PEM is then parsed with
        ``cryptography`` to confirm it is a usable private key.

        NOTE(review): this PEM string is passed as-is to pywebpush's
        ``webpush(vapid_private_key=...)``; confirm the installed pywebpush
        version accepts PEM text there.

        Returns:
            The PEM-formatted private key string.

        Raises:
            ValueError: if the resulting PEM cannot be parsed as a private key.
            Decoding errors (invalid base64, non-UTF-8 bytes) are logged and
            re-raised unchanged.
        """
        env_key = os.getenv('VAPID_PRIVATE_KEY', '').strip().strip('"\'')
        # Log metadata only: key material must never reach the logs, even at
        # DEBUG level.
        logger.debug(f"Raw env key length: {len(env_key)}")
        try:
            if '\\n' in env_key:
                # Escaped newlines, as written in .env files
                private_pem = env_key.replace('\\n', '\n')
                key_format = "PEM with escaped newlines"
            elif '-----BEGIN PRIVATE KEY-----' in env_key:
                # Already in PEM format
                private_pem = env_key
                key_format = "PEM"
            else:
                # Assume base64url-encoded PEM text
                private_pem = base64.urlsafe_b64decode(env_key).decode('utf-8')
                key_format = "base64-encoded"
            logger.debug(f"VAPID private key format detected: {key_format}")
            # Wrap a bare key body in PEM markers
            if not private_pem.startswith('-----BEGIN PRIVATE KEY-----'):
                private_pem = f"-----BEGIN PRIVATE KEY-----\n{private_pem}\n-----END PRIVATE KEY-----"
        except Exception as e:
            logger.error(f"VAPID key loading failed: {e}")
            raise
        # Final validation: fail fast at startup rather than on first push.
        try:
            serialization.load_pem_private_key(private_pem.encode('utf-8'), password=None)
        except Exception as e:
            logger.error(f"Key validation failed: {e}")
            raise ValueError("Invalid private key format") from e
        logger.debug("VAPID private key successfully loaded")
        return private_pem

    def load_subscriptions(self) -> List[Dict]:
        """Load web push subscriptions from the subscriptions file.

        Returns:
            The JSON list stored in ``self.subscriptions_file``. An empty list
            is returned when the file is empty, is not valid JSON, or does not
            hold a JSON list (the last two are logged as errors).
        """
        with open(self.subscriptions_file, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        # Handle empty file case
        if not content:
            return []
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            logger.error(f"Error decoding subscriptions from {self.subscriptions_file}")
            return []
        if not isinstance(data, list):
            logger.error(f"Subscriptions file {self.subscriptions_file} does not contain a JSON list; ignoring it")
            return []
        return data

    def save_subscriptions(self):
        """Write ``self.subscriptions`` to the subscriptions file as JSON.

        Best effort: errors are logged, not raised.
        """
        try:
            with open(self.subscriptions_file, 'w', encoding='utf-8') as f:
                json.dump(self.subscriptions, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving subscriptions: {e}")

    async def send_push_notification(self, subscription: Dict, message: str):
        """Send one web push message to a single subscription.

        ``webpush`` is a blocking HTTP call, so it runs in the event loop's
        default executor. This keeps the server responsive and lets
        ``broadcast_notification`` reach subscribers concurrently.

        On ``WebPushException`` the subscription is removed (and the file
        re-saved) only when the push service answered 404/410 (subscription
        gone) or when no HTTP response exists (pywebpush rejected the
        subscription data itself). Other push errors are logged and the
        subscription is kept, so a transient outage does not unsubscribe
        everyone. Other exception types propagate to the caller.
        """
        if not self.subscriptions:
            logger.warning("No subscriptions available")
            return
        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(
                None,
                lambda: webpush(
                    subscription_info=subscription,
                    data=message,
                    vapid_private_key=self.vapid_private_key,
                    # Build a fresh claims dict per call: pywebpush may add
                    # "aud"/"exp" to the dict it receives, and "aud" must
                    # match each subscription's own push service.
                    vapid_claims={"sub": self.vapid_subject},
                ),
            )
        except WebPushException as e:
            logger.error(f"Push notification error: {e}")
            response = getattr(e, 'response', None)
            status = getattr(response, 'status_code', None)
            if response is None or status in (404, 410):
                # Remove invalid or expired subscription
                self.subscriptions = [s for s in self.subscriptions if s != subscription]
                self.save_subscriptions()

    async def handle_button1(self):
        """Handle first button action - e.g., Home Lights On"""
        logger.info("Button 1 pressed: Home Lights On")
        message = json.dumps({"action": "home_lights_on"})
        await self.broadcast_notification(message)

    async def handle_button2(self):
        """Handle second button action - e.g., Security System Arm"""
        logger.info("Button 2 pressed: Security System Arm")
        message = json.dumps({"action": "security_arm"})
        await self.broadcast_notification(message)

    async def handle_button3(self):
        """Handle third button action - e.g., Panic Button"""
        logger.info("Button 3 pressed: Panic Alert")
        message = json.dumps({"action": "panic_alert"})
        await self.broadcast_notification(message)

    async def broadcast_notification(self, message: str):
        """Send *message* to every stored subscription concurrently.

        An unexpected error for one subscription is logged and does not stop
        delivery to the others.
        """
        if not self.subscriptions:
            logger.warning("No subscriptions to broadcast to")
            return
        # Snapshot: send_push_notification may rebind self.subscriptions.
        recipients = list(self.subscriptions)
        results = await asyncio.gather(
            *(self.send_push_notification(sub, message) for sub in recipients),
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"Unexpected error sending push notification: {result!r}")

    async def handle_flic_webhook(self, request):
        """Webhook endpoint for Flic button events (POST /flic-webhook).

        Expects a JSON object whose ``serial`` field names a configured
        button.

        NOTE(review): there is no authentication here; anyone who can reach
        the endpoint and knows a serial can trigger a broadcast.

        Returns:
            200 after the button's handler ran; 400 for a body that is not
            valid JSON or for an unknown/missing serial; 500 if the handler
            raised.
        """
        try:
            data = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            logger.warning("Flic webhook body is not valid JSON")
            return web.Response(status=400)
        button_serial = data.get('serial') if isinstance(data, dict) else None
        # Serials are strings; a non-string (e.g. list) would not be hashable
        # or could never match, so treat it as unknown.
        handler = self.button_configs.get(button_serial) if isinstance(button_serial, str) else None
        if handler is None:
            logger.warning(f"Unknown button serial: {button_serial}")
            return web.Response(status=400)
        try:
            await handler()
        except Exception:
            logger.exception("Error processing Flic webhook")
            return web.Response(status=500)
        return web.Response(status=200)

    @staticmethod
    def _is_valid_subscription(subscription) -> bool:
        """Return True if *subscription* looks like a browser PushSubscription.

        Requires an ``https://`` ``endpoint`` string and a ``keys`` object
        with ``p256dh`` and ``auth`` strings. Restricting to https also stops
        clients from making this server POST to arbitrary plain-http URLs.
        """
        if not isinstance(subscription, dict):
            return False
        endpoint = subscription.get('endpoint')
        keys = subscription.get('keys')
        return (
            isinstance(endpoint, str)
            and endpoint.startswith('https://')
            and isinstance(keys, dict)
            and isinstance(keys.get('p256dh'), str)
            and isinstance(keys.get('auth'), str)
        )

    async def handle_subscribe(self, request):
        """Add a new web push subscription (POST /subscribe).

        An existing entry with the same endpoint is replaced rather than
        duplicated. An identical subscription is a no-op.

        Returns:
            200 when stored or already present, 400 for a body that is not
            valid JSON or not a valid subscription, 500 on unexpected errors.
        """
        try:
            subscription = await request.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            logger.warning("Subscription body is not valid JSON")
            return web.Response(status=400)
        if not self._is_valid_subscription(subscription):
            logger.warning("Rejected malformed subscription")
            return web.Response(status=400)
        try:
            if subscription not in self.subscriptions:
                endpoint = subscription['endpoint']
                self.subscriptions = [
                    s for s in self.subscriptions
                    if not (isinstance(s, dict) and s.get('endpoint') == endpoint)
                ]
                self.subscriptions.append(subscription)
                self.save_subscriptions()
                logger.info("New subscription added")
            return web.Response(status=200)
        except Exception:
            logger.exception("Subscription error")
            return web.Response(status=500)
def create_app():
    """Create and configure the aiohttp application.

    Registers:
      * POST ``/flic-webhook`` and POST ``/subscribe`` on one shared
        ``FlicButtonHandler``;
      * OPTIONS handlers on both paths that answer CORS preflight requests;
      * an ``on_response_prepare`` hook that adds CORS headers to responses.

    Returns:
        The configured ``web.Application``.
    """
    app = web.Application()
    handler = FlicButtonHandler()

    async def options_handler(request):
        """Answer a CORS preflight: 200 with CORS headers for an allowed
        Origin, 403 otherwise."""
        origin = request.headers.get('Origin', '')
        if origin not in ALLOWED_ORIGINS:
            return web.Response(status=403)  # Forbidden origin
        return web.Response(
            status=200,
            headers={
                'Access-Control-Allow-Origin': origin,
                'Access-Control-Allow-Methods': ', '.join(ALLOWED_METHODS),
                'Access-Control-Allow-Headers': ', '.join(ALLOWED_HEADERS),
                'Access-Control-Max-Age': '86400',  # 24 hours
            },
        )

    async def add_cors_headers(request, response):
        """Add CORS headers to every outgoing response.

        For an allowed Origin the allow-origin header echoes that origin.
        Because the response therefore depends on the request's Origin,
        ``Vary: Origin`` is always added so shared caches do not serve one
        origin's response to another.
        """
        origin = request.headers.get('Origin', '')
        if origin in ALLOWED_ORIGINS:
            response.headers['Access-Control-Allow-Origin'] = origin
            response.headers['Access-Control-Expose-Headers'] = 'Content-Type'
        vary = response.headers.get('Vary')
        if not vary:
            response.headers['Vary'] = 'Origin'
        elif 'origin' not in vary.lower():
            # Keep any existing Vary entries (e.g. Accept-Encoding).
            response.headers['Vary'] = f"{vary}, Origin"
        return response

    # Register the hook for all responses, including the preflight ones.
    app.on_response_prepare.append(add_cors_headers)
    # CORS preflight routes
    app.router.add_route('OPTIONS', '/flic-webhook', options_handler)
    app.router.add_route('OPTIONS', '/subscribe', options_handler)
    # API routes
    app.router.add_post('/flic-webhook', handler.handle_flic_webhook)
    app.router.add_post('/subscribe', handler.handle_subscribe)
    return app
async def main():
    """Serve the application until SIGINT or SIGTERM is received.

    Binds to all interfaces on the port given by the ``PORT`` environment
    variable (default 8080). The runner is cleaned up in ``finally`` so it is
    released even if startup fails or the task is cancelled.
    """
    port = int(os.getenv('PORT', '8080'))
    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, '0.0.0.0', port)
        await site.start()
        logger.info(f"Application started on port {port}")
        # Set by the signal handler; waiting on it keeps the server running.
        stop_event = asyncio.Event()

        def signal_handler():
            """Handle interrupt signals to gracefully stop the application."""
            logger.info("Received shutdown signal")
            stop_event.set()

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, signal_handler)
            except NotImplementedError:
                # Some event loops (e.g. on Windows) lack signal handlers;
                # Ctrl+C then surfaces as KeyboardInterrupt/cancellation and
                # the finally block below still cleans up.
                logger.debug(f"Signal handler for {sig!r} not supported")
        await stop_event.wait()
    finally:
        await runner.cleanup()
        logger.info("Application shutting down")


if __name__ == '__main__':
    asyncio.run(main())