diff --git a/app.py b/app.py index 2df65cb..436a491 100644 --- a/app.py +++ b/app.py @@ -69,48 +69,36 @@ class FlicButtonHandler: raise def _decode_vapid_private_key(self): - """ - Final robust VAPID private key loader - Handles all possible key formats and provides detailed debugging - """ + """Load and validate VAPID private key with strict formatting.""" try: - # Get and clean the key from environment env_key = os.getenv('VAPID_PRIVATE_KEY', '').strip().strip('"\'') - - # Debug output - logger.debug(f"Raw env key length: {len(env_key)}") - logger.debug(f"Key starts with: {env_key[:50]}") - - # Convert to PEM format + + # Convert to consistent PEM format if '\\n' in env_key: - # Handle escaped newlines (from .env file) private_pem = env_key.replace('\\n', '\n') - elif '-----BEGIN PRIVATE KEY-----' in env_key: - # Already in PEM format - private_pem = env_key else: - # Assume base64 encoded - private_pem = base64.urlsafe_b64decode(env_key).decode('utf-8') - - # Ensure proper PEM format + private_pem = env_key + + # Ensure proper PEM headers if not private_pem.startswith('-----BEGIN PRIVATE KEY-----'): private_pem = f"-----BEGIN PRIVATE KEY-----\n{private_pem}\n-----END PRIVATE KEY-----" + + # Validate by loading + key = serialization.load_pem_private_key( + private_pem.encode('utf-8'), + password=None + ) + + # Return standardized PEM + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode('utf-8') - # Final validation - try: - key = serialization.load_pem_private_key( - private_pem.encode('utf-8'), - password=None - ) - logger.debug("VAPID private key successfully loaded") - return private_pem - except Exception as e: - logger.error(f"Key validation failed: {str(e)}") - raise ValueError(f"Invalid private key format") from e - except Exception as e: logger.error(f"VAPID key loading failed: {str(e)}") - raise + raise ValueError("Invalid VAPID private key format") from e def load_subscriptions(self) -> List[Dict]: """Load web push subscriptions from file.""" @@ -132,32 +120,50 @@ class FlicButtonHandler: logger.error(f"Error saving subscriptions: {e}") async def send_push_notification(self, subscription: Dict, message: str): - """Send a web push notification with proper key handling.""" + """Send a web push notification with robust key handling.""" try: if not self.subscriptions: logger.warning("No subscriptions available") return - # Debug output - logger.debug(f"Using VAPID key: {self.vapid_private_key[:50]}...") - logger.debug(f"Subscription endpoint: {subscription['endpoint'][:50]}...") + # Convert PEM key to bytes right before use + try: + private_key = serialization.load_pem_private_key( + self.vapid_private_key.encode('utf-8'), + password=None + ) + # Re-serialize to ensure clean format + vapid_private_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode('utf-8') + except Exception as e: + logger.error(f"Key conversion failed: {str(e)}") + raise + + # Extract domain for aud claim + endpoint = subscription['endpoint'] + aud = endpoint.split('/fcm/send')[0] if 'fcm.googleapis.com' in endpoint else endpoint.split('/send')[0] + + logger.debug(f"Using aud: {aud}") + logger.debug(f"Key length: {len(vapid_private_key)}") webpush( subscription_info=subscription, data=message, - vapid_private_key=self.vapid_private_key, + vapid_private_key=vapid_private_key, vapid_claims={ "sub": os.getenv('VAPID_CLAIM_EMAIL', 'mailto:your-email@example.com'), - "aud": subscription['endpoint'].split('/send/')[0] + "/" + "aud": aud } ) logger.info("Push notification sent successfully") except WebPushException as e: - logger.error(f"Push notification error: {str(e)}") + logger.error(f"Push failed: {str(e)}") if 'Invalid JWT' in str(e): - logger.error("VAPID key validation failed - check key format") - self.subscriptions = [s for s in self.subscriptions if s != subscription] - self.save_subscriptions() + logger.error("VAPID key validation failed") + raise async def handle_button1(self): """Handle first button action - e.g., Home Lights On"""