273 lines
9.5 KiB
Python
273 lines
9.5 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import base64
|
|
from typing import Dict, List
|
|
import signal
|
|
import pathlib
|
|
|
|
import aiohttp
|
|
from aiohttp import web
|
|
from dotenv import load_dotenv
|
|
from pywebpush import webpush, WebPushException
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
|
|
# Load environment variables from a .env file into os.environ so that the
# os.getenv(...) lookups below (LOG_LEVEL, button serials, subscriptions
# file, VAPID key) can see values defined there.
load_dotenv()
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=getattr(logging, os.getenv('LOG_LEVEL', 'INFO')),
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# CORS settings.
# Browser origins that may call this API; the request's Origin is echoed back
# only when it appears here. Extend the list with further trusted origins.
ALLOWED_ORIGINS = ["https://game-timer.virtonline.eu"]

# Methods and request headers advertised in preflight (OPTIONS) responses.
ALLOWED_METHODS = ["POST", "OPTIONS"]
ALLOWED_HEADERS = ["Content-Type"]
|
|
|
|
class FlicButtonHandler:
    """Route Flic button webhooks to actions and fan them out as Web Push notifications.

    Configuration is read from the environment:

    * ``FLIC_BUTTON1_SERIAL`` .. ``FLIC_BUTTON3_SERIAL`` - serials of the buttons
      mapped to ``handle_button1`` .. ``handle_button3``; unset ones are ignored.
    * ``SUBSCRIPTIONS_FILE`` - JSON file holding the push subscriptions
      (default ``app/subscriptions.json``).
    * ``VAPID_PRIVATE_KEY`` - PEM private key used to sign pushes; literal
      ``\\n`` sequences are converted to newlines.
    * ``VAPID_CLAIM_SUB`` - optional VAPID ``sub`` claim; defaults to
      ``mailto:your-email@example.com``.
    """

    # (environment variable holding a button serial, handler method it triggers)
    _BUTTON_ENV_VARS = (
        ('FLIC_BUTTON1_SERIAL', 'handle_button1'),
        ('FLIC_BUTTON2_SERIAL', 'handle_button2'),
        ('FLIC_BUTTON3_SERIAL', 'handle_button3'),
    )

    def __init__(self):
        # Map configured button serials to their async handlers.
        self.button_configs = self._load_button_configs()
        if not self.button_configs:
            logger.warning("No Flic button serials configured")

        # Ensure subscriptions file and directory exist
        self.subscriptions_file = os.getenv('SUBSCRIPTIONS_FILE', 'app/subscriptions.json')
        self._ensure_subscriptions_file()

        # Load subscriptions
        self.subscriptions = self.load_subscriptions()

        # Prepare VAPID keys; raises at startup if the key is missing or invalid.
        self.vapid_private_key = self._decode_vapid_private_key()

    def _load_button_configs(self) -> Dict:
        """Return ``{serial: bound handler}`` for every button whose serial is set.

        Unset or blank variables are skipped. Storing them under a ``None`` key
        would let a webhook body without a ``serial`` field trigger an action.
        """
        configs = {}
        for env_var, method_name in self._BUTTON_ENV_VARS:
            serial = os.getenv(env_var)
            if serial and serial.strip():
                configs[serial] = getattr(self, method_name)
        return configs

    def _ensure_subscriptions_file(self):
        """Create the subscriptions file (as an empty JSON list) and its parent directory if missing.

        Any error is logged and re-raised, so construction fails loudly.
        """
        path = pathlib.Path(self.subscriptions_file)
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            if not path.exists():
                path.write_text(json.dumps([]), encoding='utf-8')
        except Exception as e:
            logger.error(f"Error ensuring subscriptions file: {e}")
            raise

    def _decode_vapid_private_key(self):
        """Load and validate the VAPID private key from ``VAPID_PRIVATE_KEY``.

        The .env file stores the PEM on one line with literal ``\\n`` escapes;
        they are converted back to real newlines before validation.

        Returns:
            str: the PEM-encoded private key.

        Raises:
            ValueError: if the variable is unset or lacks the
                ``-----BEGIN PRIVATE KEY-----`` header.
            Any error raised by ``serialization.load_pem_private_key`` when the
            key cannot be parsed. All errors are logged before re-raising.
        """
        try:
            env_key = os.getenv('VAPID_PRIVATE_KEY', '').strip()
            if not env_key:
                raise ValueError("VAPID_PRIVATE_KEY is not set")

            # Convert escaped newlines back to actual newlines
            private_pem = env_key.replace('\\n', '\n')

            # Only PKCS#8 ("BEGIN PRIVATE KEY") PEM is accepted.
            if not private_pem.startswith('-----BEGIN PRIVATE KEY-----'):
                raise ValueError("Invalid PEM format")

            # Parse once so a malformed key fails at startup, not on the first push.
            serialization.load_pem_private_key(private_pem.encode('utf-8'), password=None)
            return private_pem
        except Exception as e:
            logger.error(f"VAPID key error: {str(e)}")
            raise

    def load_subscriptions(self) -> List[Dict]:
        """Read push subscriptions from the subscriptions file.

        Returns an empty list when the file is missing, empty, not valid JSON,
        or does not hold a JSON list, so a damaged file never blocks startup.
        """
        try:
            with open(self.subscriptions_file, 'r', encoding='utf-8') as f:
                content = f.read().strip()
        except FileNotFoundError:
            return []
        if not content:
            return []
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            logger.error(f"Error decoding subscriptions from {self.subscriptions_file}")
            return []
        if not isinstance(data, list):
            logger.error(f"Subscriptions file {self.subscriptions_file} does not contain a JSON list")
            return []
        return data

    def save_subscriptions(self):
        """Persist the current subscriptions to the subscriptions file (best effort).

        The JSON is serialized before the file is opened, so a serialization
        error cannot leave a truncated file behind. Errors are logged, not raised.
        """
        try:
            payload = json.dumps(self.subscriptions, indent=2)
            with open(self.subscriptions_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except Exception as e:
            logger.error(f"Error saving subscriptions: {e}")

    @staticmethod
    def _is_subscription_gone(exc) -> bool:
        """Return True if *exc* carries a 404/410 response, i.e. the subscription no longer exists."""
        response = getattr(exc, 'response', None)
        return getattr(response, 'status_code', None) in (404, 410)

    @staticmethod
    def _is_valid_subscription(subscription) -> bool:
        """Return True if *subscription* has a non-empty ``endpoint`` string and a ``keys`` object."""
        return (
            isinstance(subscription, dict)
            and isinstance(subscription.get('endpoint'), str)
            and bool(subscription['endpoint'])
            and isinstance(subscription.get('keys'), dict)
        )

    async def send_push_notification(self, subscription: Dict, message: str):
        """Send *message* to one subscription.

        A subscription is removed only when the push service reports it gone
        (404/410); other push errors are logged and the subscription is kept.
        """
        if not self.subscriptions:
            logger.warning("No subscriptions available")
            return

        # Fresh dict per call: webpush() may add 'aud'/'exp' to the claims it is given.
        claims = {"sub": os.getenv('VAPID_CLAIM_SUB', "mailto:your-email@example.com")}
        loop = asyncio.get_running_loop()
        try:
            # webpush() performs a blocking HTTP request; running it in the default
            # executor keeps the event loop responsive and lets broadcasts overlap.
            await loop.run_in_executor(
                None,
                lambda: webpush(
                    subscription_info=subscription,
                    data=message,
                    vapid_private_key=self.vapid_private_key,
                    vapid_claims=claims,
                ),
            )
        except WebPushException as e:
            logger.error(f"Push notification error: {e}")
            if self._is_subscription_gone(e):
                # Remove invalid subscription
                self.subscriptions = [s for s in self.subscriptions if s != subscription]
                self.save_subscriptions()

    async def handle_button1(self):
        """Handle first button action - e.g., Home Lights On"""
        logger.info("Button 1 pressed: Home Lights On")
        message = json.dumps({"action": "home_lights_on"})
        await self.broadcast_notification(message)

    async def handle_button2(self):
        """Handle second button action - e.g., Security System Arm"""
        logger.info("Button 2 pressed: Security System Arm")
        message = json.dumps({"action": "security_arm"})
        await self.broadcast_notification(message)

    async def handle_button3(self):
        """Handle third button action - e.g., Panic Button"""
        logger.info("Button 3 pressed: Panic Alert")
        message = json.dumps({"action": "panic_alert"})
        await self.broadcast_notification(message)

    async def broadcast_notification(self, message: str):
        """Send *message* to every stored subscription concurrently.

        Failures are logged per subscription instead of propagating, so one
        failing endpoint does not turn the whole webhook request into a 500.
        """
        if not self.subscriptions:
            logger.warning("No subscriptions to broadcast to")
            return

        # Snapshot: send_push_notification may rebind self.subscriptions meanwhile.
        targets = list(self.subscriptions)
        results = await asyncio.gather(
            *(self.send_push_notification(sub, message) for sub in targets),
            return_exceptions=True,
        )
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"Unexpected error sending push notification: {result}")

    async def handle_flic_webhook(self, request):
        """Webhook endpoint for Flic button events (JSON body with a ``serial`` field).

        Returns 200 once the mapped action ran, 400 for an unparseable body or
        an unknown/missing serial, and 500 if the action raised.
        """
        try:
            data = await request.json()
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            logger.warning(f"Invalid JSON in Flic webhook: {e}")
            return web.Response(status=400)

        button_serial = data.get('serial') if isinstance(data, dict) else None

        # Validate button serial; keys are strings from the environment, so
        # non-string values can never match (and may be unhashable).
        handler = self.button_configs.get(button_serial) if isinstance(button_serial, str) else None
        if handler is None:
            logger.warning(f"Unknown button serial: {button_serial}")
            return web.Response(status=400)

        try:
            await handler()
        except Exception as e:
            logger.error(f"Error processing Flic webhook: {e}")
            return web.Response(status=500)
        return web.Response(status=200)

    async def handle_subscribe(self, request):
        """Add or refresh a Web Push subscription posted by the browser.

        The body must be an object with a non-empty ``endpoint`` string and a
        ``keys`` object. An entry with the same endpoint is replaced (its keys
        may have changed). Returns 200 on success, 400 for a malformed body and
        500 on an unexpected error.
        """
        try:
            subscription = await request.json()
        except ValueError as e:
            logger.warning(f"Invalid subscription payload: {e}")
            return web.Response(status=400)

        if not self._is_valid_subscription(subscription):
            logger.warning("Rejected malformed subscription")
            return web.Response(status=400)

        try:
            # Check if subscription already exists
            if subscription not in self.subscriptions:
                endpoint = subscription['endpoint']
                self.subscriptions = [
                    s for s in self.subscriptions
                    if not (isinstance(s, dict) and s.get('endpoint') == endpoint)
                ] + [subscription]
                self.save_subscriptions()
                logger.info("New subscription added")
            return web.Response(status=200)
        except Exception as e:
            logger.error(f"Subscription error: {e}")
            return web.Response(status=500)
|
|
|
|
def create_app():
    """Create and configure the aiohttp application.

    Routes:
        POST /flic-webhook -> FlicButtonHandler.handle_flic_webhook
        POST /subscribe    -> FlicButtonHandler.handle_subscribe
        OPTIONS on both    -> CORS preflight built from the ALLOWED_* settings

    Constructing FlicButtonHandler reads configuration from the environment;
    any error it raises propagates out of this function.
    """
    app = web.Application()
    handler = FlicButtonHandler()

    def add_vary_origin(headers):
        """Append ``Origin`` to the Vary header unless it is already listed."""
        vary = headers.get('Vary')
        if not vary:
            headers['Vary'] = 'Origin'
        elif 'origin' not in [part.strip().lower() for part in vary.split(',')]:
            headers['Vary'] = f"{vary}, Origin"

    async def options_handler(request):
        """Handle OPTIONS requests for CORS preflight: 200 for a listed origin, else 403."""
        origin = request.headers.get('Origin', '')
        if origin in ALLOWED_ORIGINS:
            headers = {
                'Access-Control-Allow-Origin': origin,
                'Access-Control-Allow-Methods': ', '.join(ALLOWED_METHODS),
                'Access-Control-Allow-Headers': ', '.join(ALLOWED_HEADERS),
                'Access-Control-Max-Age': '86400',  # 24 hours
            }
            return web.Response(status=200, headers=headers)
        return web.Response(status=403)  # Forbidden origin

    async def add_cors_headers(request, response):
        """on_response_prepare hook: echo an allowed Origin back on every response."""
        # Access-Control-Allow-Origin depends on the request's Origin, so every
        # response must declare Vary: Origin; otherwise a shared cache could
        # replay one origin's CORS headers to a different origin.
        add_vary_origin(response.headers)
        origin = request.headers.get('Origin', '')
        if origin in ALLOWED_ORIGINS:
            response.headers['Access-Control-Allow-Origin'] = origin
            response.headers['Access-Control-Expose-Headers'] = 'Content-Type'

    # Register middleware
    app.on_response_prepare.append(add_cors_headers)

    # Setup routes with OPTIONS handlers
    app.router.add_route('OPTIONS', '/flic-webhook', options_handler)
    app.router.add_route('OPTIONS', '/subscribe', options_handler)

    # Original routes
    app.router.add_post('/flic-webhook', handler.handle_flic_webhook)
    app.router.add_post('/subscribe', handler.handle_subscribe)

    return app
|
|
|
|
async def main():
    """Main application entry point: serve HTTP until SIGINT/SIGTERM, then shut down.

    Listens on ``LISTEN_HOST``:``LISTEN_PORT`` (defaults ``0.0.0.0``:``8080``).
    The runner is cleaned up in a ``finally`` block, so cancellation or a
    failed bind still releases its resources.
    """
    host = os.getenv('LISTEN_HOST', '0.0.0.0')
    port = int(os.getenv('LISTEN_PORT', '8080'))

    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
        logger.info(f"Application started on port {port}")

        # Create an event to keep the application running
        stop_event = asyncio.Event()

        def signal_handler():
            """Handle interrupt signals to gracefully stop the application."""
            logger.info("Received shutdown signal")
            stop_event.set()

        # Register signal handlers
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, signal_handler)
            except NotImplementedError:
                # Not supported by this event loop (e.g. on Windows); Ctrl+C then
                # arrives via the default KeyboardInterrupt path, and the task's
                # cancellation still reaches the finally block below.
                pass

        # Wait until stop event is set
        await stop_event.wait()
    finally:
        # Cleanup
        logger.info("Application shutting down")
        await runner.cleanup()
|
|
|
|
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C on event loops without signal-handler support (e.g. Windows)
        # surfaces here; exit without printing a traceback.
        pass