Harden auth/upload, fix PR-delete cascade and sync backfill
- OIDC: require signed short-lived state on login callback; reject missing userinfo sub (account-takeover guard); validate token exchange + userinfo responses - Upload: safe zip extraction (path-traversal + zip-bomb cap), streamed size-capped writes, sanitised filenames - Garmin: increasing lookback resets last_sync_at for one-time backfill - Activities: delete/reprocess remove PersonalRecord rows (no FK cascade) - Profile: validate /weight limit; sync lookback UI copy - Dashboard: sleep shading uses same day as charted body battery Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import get_current_user
|
||||
from app.models.user import User, Activity, ActivityDataPoint, ActivityLap
|
||||
from app.models.user import User, Activity, ActivityDataPoint, ActivityLap, PersonalRecord
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -266,6 +266,10 @@ async def delete_activity(
|
||||
activity = result.scalar_one_or_none()
|
||||
if not activity:
|
||||
raise HTTPException(status_code=404, detail="Activity not found")
|
||||
# PersonalRecord.activity_id has no cascade, so remove the activity's PR rows
|
||||
# first or the delete fails the FK constraint. (segment_efforts cascade in DB;
|
||||
# data_points/laps cascade via the ORM relationship.)
|
||||
await db.execute(delete(PersonalRecord).where(PersonalRecord.activity_id == activity_id))
|
||||
await db.delete(activity)
|
||||
await db.commit()
|
||||
|
||||
@@ -297,6 +301,8 @@ async def reprocess_activity(
|
||||
|
||||
await db.execute(delete(ActivityDataPoint).where(ActivityDataPoint.activity_id == activity_id))
|
||||
await db.execute(delete(ActivityLap).where(ActivityLap.activity_id == activity_id))
|
||||
# Drop PR rows referencing this activity (no cascade); the re-parse re-computes them.
|
||||
await db.execute(delete(PersonalRecord).where(PersonalRecord.activity_id == activity_id))
|
||||
await db.delete(activity)
|
||||
await db.commit()
|
||||
|
||||
|
||||
+38
-1
@@ -19,6 +19,7 @@ router = APIRouter()
|
||||
# to a normal sign-in), so the callback attaches the passkey to a known user
|
||||
# instead of creating/looking-up by identity.
|
||||
LINK_STATE_PURPOSE = "pocketid-link"
|
||||
LOGIN_STATE_PURPOSE = "pocketid-login"
|
||||
|
||||
|
||||
def _make_link_state(user_id: int) -> str:
|
||||
@@ -29,6 +30,25 @@ def _make_link_state(user_id: int) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _make_login_state() -> str:
|
||||
"""Signed, short-lived CSRF token proving the login flow started from this app."""
|
||||
return create_access_token(
|
||||
{"sub": "login", "purpose": LOGIN_STATE_PURPOSE},
|
||||
expires_delta=timedelta(minutes=10),
|
||||
)
|
||||
|
||||
|
||||
def _valid_login_state(state: Optional[str]) -> bool:
|
||||
"""True if `state` is a valid, unexpired login-state token we issued."""
|
||||
if not state:
|
||||
return False
|
||||
try:
|
||||
payload = jwt.decode(state, settings.secret_key, algorithms=[settings.algorithm])
|
||||
return payload.get("purpose") == LOGIN_STATE_PURPOSE
|
||||
except JWTError:
|
||||
return False
|
||||
|
||||
|
||||
def _decode_link_state(state: Optional[str]) -> Optional[int]:
|
||||
"""Return the user id from a valid link-state token, else None."""
|
||||
if not state:
|
||||
@@ -157,6 +177,7 @@ async def pocketid_login_url(db: AsyncSession = Depends(get_db)):
|
||||
"redirect_uri": f"{settings.base_url}/api/auth/pocketid/callback",
|
||||
"response_type": "code",
|
||||
"scope": "openid profile email groups",
|
||||
"state": _make_login_state(),
|
||||
}
|
||||
return {"url": f"{issuer}/authorize?{urlencode(params)}"}
|
||||
|
||||
@@ -202,10 +223,15 @@ async def pocketid_callback(code: str, state: Optional[str] = None, db: AsyncSes
|
||||
print(f"PocketID token exchange failed ({resp.status_code}): {resp.text}")
|
||||
raise HTTPException(status_code=400, detail="Token exchange failed")
|
||||
tokens = resp.json()
|
||||
access_token = tokens.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=400, detail="Token exchange failed")
|
||||
userinfo_resp = await client.get(
|
||||
f"{issuer}/api/oidc/userinfo",
|
||||
headers={"Authorization": f"Bearer {tokens['access_token']}"},
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
if userinfo_resp.status_code != 200:
|
||||
raise HTTPException(status_code=400, detail="Failed to fetch user info")
|
||||
userinfo = userinfo_resp.json()
|
||||
|
||||
from fastapi.responses import RedirectResponse
|
||||
@@ -214,6 +240,12 @@ async def pocketid_callback(code: str, state: Optional[str] = None, db: AsyncSes
|
||||
email = userinfo.get("email")
|
||||
preferred_username = userinfo.get("preferred_username") or email
|
||||
|
||||
# A missing subject means we cannot identify the user. Never continue, or the
|
||||
# `pocketid_sub == sub` (== None → IS NULL) lookups below would match any
|
||||
# password-only account and log the caller in as someone else.
|
||||
if not sub:
|
||||
return RedirectResponse(url="/login?auth_error=no_identity")
|
||||
|
||||
# ── Explicit account-link flow ──────────────────────────────────────────
|
||||
# Initiated by an already-authenticated user from their profile. Attach the
|
||||
# passkey to that account. No group gating here: this is identity linking,
|
||||
@@ -238,6 +270,11 @@ async def pocketid_callback(code: str, state: Optional[str] = None, db: AsyncSes
|
||||
target.email = email
|
||||
return RedirectResponse(url="/profile?linked=1")
|
||||
|
||||
# Normal sign-in: require the signed, short-lived state we issued in
|
||||
# /pocketid/login-url, so the callback can't be driven by an injected code.
|
||||
if not _valid_login_state(state):
|
||||
return RedirectResponse(url="/login?auth_error=invalid_state")
|
||||
|
||||
# Group gating: if an allowed group is configured, the user must be in it.
|
||||
allowed_group = await _get_allowed_group(db)
|
||||
if allowed_group:
|
||||
|
||||
@@ -37,6 +37,17 @@ class GarminConfigOut(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
def _wants_more_history(old: int, new: int) -> bool:
|
||||
"""True if `new` lookback requests older data than `old` (-1 = all-time)."""
|
||||
if new == old:
|
||||
return False
|
||||
if new == -1: # all-time requested where it wasn't before
|
||||
return True
|
||||
if old == -1: # was all-time, now finite → narrower, not more
|
||||
return False
|
||||
return new > old
|
||||
|
||||
|
||||
@router.get("/config", response_model=GarminConfigOut)
|
||||
async def get_config(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -111,6 +122,16 @@ async def save_config(
|
||||
if not cfg:
|
||||
raise HTTPException(status_code=400, detail="No Garmin account connected — password required for first-time setup")
|
||||
|
||||
# If the user is now asking for MORE history than before, reset last_sync_at so
|
||||
# the next sync treats it as a first sync and does a one-time backfill of the
|
||||
# wider lookback window (then resumes cheap incremental syncs). Scheduled syncs
|
||||
# otherwise only refresh the last day or two, so without this an increased
|
||||
# lookback would never actually fetch the older data.
|
||||
old_lookback = cfg.sync_lookback_days if cfg.sync_lookback_days is not None else 30
|
||||
if _wants_more_history(old_lookback, body.sync_lookback_days):
|
||||
cfg.last_sync_at = None
|
||||
cfg.last_sync_status = "Lookback increased — backfill on next sync"
|
||||
|
||||
cfg.sync_enabled = body.sync_enabled
|
||||
cfg.sync_activities = body.sync_activities
|
||||
cfg.sync_wellness = body.sync_wellness
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, desc
|
||||
from pydantic import BaseModel
|
||||
@@ -192,7 +192,7 @@ class WeightOut(BaseModel):
|
||||
|
||||
@router.get("/weight", response_model=List[WeightOut])
|
||||
async def list_weight(
|
||||
limit: int = 365,
|
||||
limit: int = Query(365, ge=1, le=2000),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
+82
-39
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, BackgroundTasks
|
||||
@@ -13,18 +12,68 @@ from app.workers.tasks import process_activity_file, process_garmin_health_zip
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ALLOWED_EXTENSIONS = {".fit", ".gpx", ".zip"}
|
||||
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB
|
||||
MAX_FILE_SIZE = 500 * 1024 * 1024 # 500 MB upload cap
|
||||
MAX_EXTRACT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GB total uncompressed cap (zip-bomb guard)
|
||||
_CHUNK = 1024 * 1024
|
||||
|
||||
|
||||
def _safe_name(filename: str) -> str:
|
||||
"""Reduce an uploaded filename to a safe basename — no path traversal."""
|
||||
name = os.path.basename((filename or "").replace("\\", "/"))
|
||||
if not name or name in (".", ".."):
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
return name
|
||||
|
||||
|
||||
def save_upload(upload: UploadFile, dest_dir: Path) -> Path:
|
||||
"""Stream an upload to disk under dest_dir, enforcing the size cap."""
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest = dest_dir / upload.filename
|
||||
dest = dest_dir / _safe_name(upload.filename)
|
||||
size = 0
|
||||
with open(dest, "wb") as f:
|
||||
shutil.copyfileobj(upload.file, f)
|
||||
while True:
|
||||
chunk = upload.file.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
if size > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
dest.unlink(missing_ok=True)
|
||||
raise HTTPException(status_code=413, detail="File exceeds the 500 MB limit")
|
||||
f.write(chunk)
|
||||
return dest
|
||||
|
||||
|
||||
def _safe_extract(zf: zipfile.ZipFile, dest_dir: Path) -> list[Path]:
|
||||
"""Extract a zip safely: skip path-traversal members, cap total uncompressed
|
||||
bytes (zip-bomb guard). Returns the list of extracted regular-file paths."""
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest_root = dest_dir.resolve()
|
||||
total = 0
|
||||
extracted: list[Path] = []
|
||||
for info in zf.infolist():
|
||||
if info.is_dir():
|
||||
continue
|
||||
target = (dest_root / info.filename).resolve()
|
||||
# Reject absolute paths and ../ traversal: the target must stay under dest_root.
|
||||
if target != dest_root and dest_root not in target.parents:
|
||||
continue
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(info) as src, open(target, "wb") as out:
|
||||
while True:
|
||||
chunk = src.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_EXTRACT_SIZE:
|
||||
out.close()
|
||||
target.unlink(missing_ok=True)
|
||||
raise HTTPException(status_code=413, detail="Archive expands beyond the size limit")
|
||||
out.write(chunk)
|
||||
extracted.append(target)
|
||||
return extracted
|
||||
|
||||
|
||||
@router.post("/activity")
|
||||
async def upload_activity(
|
||||
file: UploadFile = File(...),
|
||||
@@ -62,35 +111,31 @@ async def upload_garmin_export(
|
||||
dest_dir = Path(settings.file_store_path) / str(current_user.id) / "exports"
|
||||
dest = save_upload(file, dest_dir)
|
||||
|
||||
# Extract and queue all FIT files
|
||||
# Extract (safely) and queue all FIT files
|
||||
extract_dir = dest_dir / f"garmin_{dest.stem}"
|
||||
extract_dir.mkdir(exist_ok=True)
|
||||
|
||||
task_ids = []
|
||||
with zipfile.ZipFile(dest) as zf:
|
||||
zf.extractall(extract_dir)
|
||||
for name in zf.namelist():
|
||||
lower = name.lower()
|
||||
if lower.endswith(".fit"):
|
||||
fit_path = extract_dir / name
|
||||
task = process_activity_file.delay(str(fit_path), current_user.id, "fit")
|
||||
task_ids.append(task.id)
|
||||
elif lower.endswith(".zip"):
|
||||
# Garmin exports nest activity FIT files inside sub-zips
|
||||
# (e.g. DI-Connect-Uploaded-Files/UploadedFiles_*_Part*.zip)
|
||||
nested_zip_path = extract_dir / name
|
||||
nested_extract = nested_zip_path.parent / nested_zip_path.stem
|
||||
nested_extract.mkdir(exist_ok=True)
|
||||
try:
|
||||
with zipfile.ZipFile(nested_zip_path) as nzf:
|
||||
nzf.extractall(nested_extract)
|
||||
for nested_name in nzf.namelist():
|
||||
if nested_name.lower().endswith(".fit"):
|
||||
fit_path = nested_extract / nested_name
|
||||
task = process_activity_file.delay(str(fit_path), current_user.id, "fit")
|
||||
task_ids.append(task.id)
|
||||
except zipfile.BadZipFile:
|
||||
pass
|
||||
extracted = _safe_extract(zf, extract_dir)
|
||||
|
||||
for path in extracted:
|
||||
suffix = path.suffix.lower()
|
||||
if suffix == ".fit":
|
||||
task = process_activity_file.delay(str(path), current_user.id, "fit")
|
||||
task_ids.append(task.id)
|
||||
elif suffix == ".zip":
|
||||
# Garmin exports nest activity FIT files inside sub-zips
|
||||
# (e.g. DI-Connect-Uploaded-Files/UploadedFiles_*_Part*.zip)
|
||||
nested_extract = path.parent / path.stem
|
||||
try:
|
||||
with zipfile.ZipFile(path) as nzf:
|
||||
nested = _safe_extract(nzf, nested_extract)
|
||||
except zipfile.BadZipFile:
|
||||
nested = []
|
||||
for np in nested:
|
||||
if np.suffix.lower() == ".fit":
|
||||
task = process_activity_file.delay(str(np), current_user.id, "fit")
|
||||
task_ids.append(task.id)
|
||||
|
||||
# Queue health/wellness data extraction
|
||||
health_task = process_garmin_health_zip.delay(str(dest), current_user.id)
|
||||
@@ -116,18 +161,16 @@ async def upload_strava_export(
|
||||
dest = save_upload(file, dest_dir)
|
||||
|
||||
extract_dir = dest_dir / f"strava_{dest.stem}"
|
||||
extract_dir.mkdir(exist_ok=True)
|
||||
|
||||
task_ids = []
|
||||
with zipfile.ZipFile(dest) as zf:
|
||||
zf.extractall(extract_dir)
|
||||
for name in zf.namelist():
|
||||
lower = name.lower()
|
||||
if lower.endswith(".fit") or lower.endswith(".gpx"):
|
||||
file_path = extract_dir / name
|
||||
ext = Path(name).suffix[1:]
|
||||
task = process_activity_file.delay(str(file_path), current_user.id, ext)
|
||||
task_ids.append(task.id)
|
||||
extracted = _safe_extract(zf, extract_dir)
|
||||
|
||||
for path in extracted:
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in (".fit", ".gpx"):
|
||||
task = process_activity_file.delay(str(path), current_user.id, suffix[1:])
|
||||
task_ids.append(task.id)
|
||||
|
||||
return {
|
||||
"status": "queued",
|
||||
|
||||
Reference in New Issue
Block a user