diff --git a/src/extension_shield/api/main.py b/src/extension_shield/api/main.py index 3ebbc4c..b0de13f 100644 --- a/src/extension_shield/api/main.py +++ b/src/extension_shield/api/main.py @@ -62,6 +62,14 @@ # Initialize logger logger = logging.getLogger(__name__) + +def _parse_trusted_proxy_hosts() -> list[str]: + """Return the explicit proxy hosts allowed to send forwarded headers.""" + raw_hosts = os.getenv("TRUSTED_PROXY_HOSTS", "").strip() + if raw_hosts: + return [host.strip() for host in raw_hosts.split(",") if host.strip()] + return ["127.0.0.1", "localhost", "::1"] + # Import safe JSON utilities from shared module from extension_shield.utils.json_encoder import ( safe_json_dumps, @@ -361,8 +369,9 @@ async def add_security_headers(request: Request, call_next): print(f"✅ CSP: Production mode detected (STATIC_DIR={STATIC_DIR}, index.html exists)") app.add_middleware(CSPMiddleware, is_dev=_is_dev) -# Trust X-Forwarded-Proto / X-Forwarded-For from Railway/Cloudflare so request.url.scheme is correct -app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*") +# Trust forwarded headers only from explicitly allowed proxy hosts. +# Set TRUSTED_PROXY_HOSTS to your actual reverse proxy / CDN hop(s). +app.add_middleware(ProxyHeadersMiddleware, trusted_hosts=_parse_trusted_proxy_hosts()) # In-memory state lives in shared.py; import references here so existing # code in this file (and tests) can continue using module-level names. @@ -408,20 +417,9 @@ def _get_client_ip(request: Request) -> str: """ Get the client's IP address for rate limiting anonymous users. - Handles proxied requests via X-Forwarded-For and X-Real-IP headers. - Falls back to client host if no headers present. + Relies on ProxyHeadersMiddleware to rewrite request.client only when the + request came from a trusted proxy host. 
""" - # Check X-Forwarded-For header (from reverse proxy/load balancer) - x_forwarded_for = request.headers.get("x-forwarded-for") - if x_forwarded_for: - # Take the first IP (original client) - return x_forwarded_for.split(",")[0].strip() - - # Check X-Real-IP header (from nginx) - x_real_ip = request.headers.get("x-real-ip") - if x_real_ip: - return x_real_ip.strip() - # Fall back to direct client IP if request.client: return request.client.host @@ -504,6 +502,27 @@ def _require_admin_or_telemetry_key(request: Request) -> None: ) +def _require_private_scan_artifact_access( + request: Request, + extension_id: str, + payload: Optional[Dict[str, Any]] = None, +) -> None: + """Block access to private scan artifacts unless the requester owns the scan.""" + requester_id = getattr(getattr(request, "state", None), "user_id", None) + if isinstance(payload, dict): + is_private = payload.get("visibility") == "private" or payload.get("source") == "upload" + owner_id = payload.get("user_id") or scan_user_ids.get(extension_id) + else: + is_private = scan_source.get(extension_id) == "upload" + owner_id = scan_user_ids.get(extension_id) + + if not is_private: + return + + if not requester_id or not owner_id or requester_id != owner_id: + raise HTTPException(status_code=404, detail="Scan results not found") + + def _deep_scan_limit_status(rate_limit_key: str) -> Dict[str, Any]: """Get deep scan limit status. Returns unlimited in local/dev environments. Anonymous (IP-based) users get 1 scan per day; authenticated users get 3. @@ -2888,7 +2907,7 @@ async def batch_scan_status(req: BatchStatusRequest, request: Request): @app.get("/api/scan/enforcement_bundle/{extension_id}") -async def get_enforcement_bundle(extension_id: str): +async def get_enforcement_bundle(extension_id: str, http_request: Request): """ Get the governance enforcement bundle for an analyzed extension. 
@@ -2926,6 +2945,8 @@ async def get_enforcement_bundle(extension_id: str): if not results: raise HTTPException(status_code=404, detail="Scan results not found") + + _require_private_scan_artifact_access(http_request, extension_id, results) # Check if governance analysis was run governance_bundle = results.get("governance_bundle") @@ -2959,7 +2980,7 @@ async def get_enforcement_bundle(extension_id: str): @app.get("/api/scan/report/{extension_id}") -async def generate_pdf_report(extension_id: str) -> Response: +async def generate_pdf_report(extension_id: str, http_request: Request) -> Response: """ Generate a PDF security report for an analyzed extension. @@ -2989,6 +3010,8 @@ async def generate_pdf_report(extension_id: str) -> Response: if not results: raise HTTPException(status_code=404, detail="Scan results not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + # Generate PDF report try: report_generator = ReportGenerator() @@ -3038,6 +3061,8 @@ async def get_file_list(extension_id: str, http_request: Request) -> FileListRes if not results: raise HTTPException(status_code=404, detail="Extension not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + extracted_path = results.get("extracted_path") if not extracted_path or not os.path.exists(extracted_path): raise HTTPException(status_code=404, detail="Extracted files not found") @@ -3069,22 +3094,43 @@ async def get_file_content(extension_id: str, file_path: str, http_request: Requ if not results: raise HTTPException(status_code=404, detail="Extension not found") + _require_private_scan_artifact_access(http_request, extension_id, results) + extracted_path = results.get("extracted_path") if not extracted_path: raise HTTPException(status_code=404, detail="Extracted files not found") - # Construct full file path - full_path = os.path.join(extracted_path, file_path) + # Resolve extracted root path (handles relative paths stored in DB). 
+ extracted_root = Path(extracted_path) + if not extracted_root.is_absolute(): + extracted_root = Path(get_settings().extension_storage_path) / extracted_root + + try: + extracted_root = extracted_root.resolve(strict=True) + except Exception: + raise HTTPException(status_code=404, detail="Extracted files not found") + + if not extracted_root.is_dir(): + raise HTTPException(status_code=404, detail="Extracted files not found") + + # Resolve requested path and ensure it stays inside extracted root. + try: + candidate_path = (extracted_root / file_path).resolve(strict=True) + except Exception: + raise HTTPException(status_code=404, detail="File not found") - # Security check: ensure path is within extracted directory - if not os.path.abspath(full_path).startswith(os.path.abspath(extracted_path)): + try: + in_root = os.path.commonpath([str(extracted_root), str(candidate_path)]) == str(extracted_root) + except Exception: + in_root = False + if not in_root: raise HTTPException(status_code=403, detail="Access denied") - if not os.path.exists(full_path): + if not candidate_path.is_file(): raise HTTPException(status_code=404, detail="File not found") try: - with open(full_path, "r", encoding="utf-8") as f: + with open(candidate_path, "r", encoding="utf-8") as f: content = f.read() return FileContentResponse(content=content, file_path=file_path) except UnicodeDecodeError as exc: @@ -3824,7 +3870,7 @@ async def database_health_check(request: Request): @app.get("/api/scan/icon/{extension_id}") -async def get_extension_icon(extension_id: str): +async def get_extension_icon(extension_id: str, http_request: Request): """ Get extension icon from the extracted extension folder. Uses icon_path from storage when available, and falls back to persisted icon bytes. 
@@ -3850,6 +3896,9 @@ async def get_extension_icon(extension_id: str): icon_media_type = results.get("icon_media_type") else: db_icon_record = _load_icon_record_from_db(extension_id) + results = db.get_scan_result(extension_id) + if results: + scan_results[extension_id] = results extracted_path = db_icon_record.get("extracted_path") icon_path = db_icon_record.get("icon_path") icon_base64 = db_icon_record.get("icon_base64") @@ -3893,6 +3942,7 @@ async def get_extension_icon(extension_id: str): # Best practice: if we have a persisted icon blob, serve it immediately. # This avoids relying on filesystem state (ephemeral/persistent) and prevents slow fallbacks. + _require_private_scan_artifact_access(http_request, extension_id, results) persisted = _extension_icon_response_from_base64(icon_base64, icon_media_type) if persisted: return persisted diff --git a/tests/api/test_enforcement_bundle.py b/tests/api/test_enforcement_bundle.py index 0b1f90a..174e090 100644 --- a/tests/api/test_enforcement_bundle.py +++ b/tests/api/test_enforcement_bundle.py @@ -118,6 +118,38 @@ def test_get_enforcement_bundle_not_found(self, client): assert response.status_code == 404 assert "not found" in response.json()["detail"].lower() + + def test_private_scan_artifact_requires_owner(self, client, tmp_path): + """Private scan artifacts should not be readable without the owning user.""" + ext_id = "privateartifact1234567890123456" + extracted = tmp_path / ext_id + extracted.mkdir() + + scan_results[ext_id] = { + "extension_id": ext_id, + "extension_name": "Private Extension", + "status": "completed", + "visibility": "private", + "user_id": "owner-user-1", + "governance_bundle": {"decision": {"verdict": "ALLOW"}}, + "extracted_path": str(extracted), + } + + try: + endpoints = [ + f"/api/scan/enforcement_bundle/{ext_id}", + f"/api/scan/report/{ext_id}", + f"/api/scan/files/{ext_id}", + f"/api/scan/file/{ext_id}/manifest.json", + f"/api/scan/icon/{ext_id}", + ] + + for endpoint in endpoints: + 
def test_scan_file_blocks_path_traversal(client: TestClient, tmp_path) -> None:
    """/api/scan/file should block traversal outside extracted root."""
    extension_id = "pathtraversaltest1234567890123456"

    extracted_dir = tmp_path / "extracted"
    extracted_dir.mkdir()
    inside_file = extracted_dir / "manifest.json"
    inside_file.write_text('{"name": "safe"}', encoding="utf-8")

    # Sibling of the extracted root: reachable only via "../".
    outside_file = tmp_path / "outside.txt"
    outside_file.write_text("secret", encoding="utf-8")

    scan_results[extension_id] = {
        "extension_id": extension_id,
        "status": "completed",
        "visibility": "public",
        "extracted_path": str(extracted_dir),
    }

    try:
        ok = client.get(f"/api/scan/file/{extension_id}/manifest.json")
        assert ok.status_code == 200
        assert "safe" in ok.json()["content"]

        # HTTP clients may collapse literal "../" segments before sending
        # (RFC 3986 dot-segment removal), in which case the traversal guard
        # would never run. Percent-encode the dots so the handler receives
        # "../outside.txt" as the requested file path.
        blocked = client.get(
            f"/api/scan/file/{extension_id}/%2e%2e/outside.txt"
        )
        assert blocked.status_code == 403
        assert "secret" not in blocked.text
    finally:
        scan_results.pop(extension_id, None)