Coverage for src / idx_api / routers / brokerage_domains.py: 40%

230 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2025-12-28 11:09 -0700

1"""Brokerage domain endpoints for managing white-label domain mappings.""" 

2 

import asyncio
from datetime import datetime, timezone

import dns.exception
import dns.resolver
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, ConfigDict
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from idx_api.auth import AdminUser, BrokerUser, RequiredUser
from idx_api.config import settings
from idx_api.database import get_db
from idx_api.dns_providers import clear_dns_cache, get_dns_info
from idx_api.models.brokerage_domain import BrokerageDomain
from idx_api.utils.cache import invalidate_site_config

18 

# Router holding the brokerage-domain endpoints defined in this module.
router = APIRouter()

20 

21 

22# ===== Response Models ===== 

23 

24 

class BrokerageDomainResponse(BaseModel):
    """Brokerage domain as returned by the API.

    Instances are built from ``BrokerageDomain`` ORM rows via
    ``model_validate``, which is why attribute-based validation is enabled.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    model_config = ConfigDict(from_attributes=True)

    id: int
    brokerage_id: int  # brokerage that owns this domain
    domain: str  # hostname as stored (create/update lowercase and strip it)
    is_primary: bool  # at most one primary per brokerage is maintained by create/update
    is_verified: bool  # set by the verify endpoint once the CNAME check passes
    verified_at: datetime | None  # None until verified; cleared when the domain name changes
    notes: str | None
    created_at: datetime
    updated_at: datetime

40 

41 

class BrokerageDomainCreate(BaseModel):
    """Request body for creating a brokerage domain mapping.

    No normalization happens in the model itself; the create endpoint
    lowercases and strips ``domain`` before storing it.
    """

    brokerage_id: int  # brokerage that will own the domain
    domain: str  # hostname to map to the brokerage
    is_primary: bool = False  # if True, the brokerage's other domains lose primary status
    notes: str | None = None  # optional free-form notes

49 

50 

class BrokerageDomainUpdate(BaseModel):
    """Partial update for a brokerage domain.

    Every field is optional. The update endpoint applies only the fields the
    client explicitly sent (``model_dump(exclude_unset=True)``).
    """

    domain: str | None = None  # new hostname; changing it resets verification
    is_primary: bool | None = None  # True demotes the brokerage's other primary domains
    notes: str | None = None  # replacement notes

57 

58 

class PaginatedDomains(BaseModel):
    """One page of brokerage domains plus pagination metadata."""

    items: list[BrokerageDomainResponse]  # domains on the requested page
    total: int  # number of matching domains across all pages
    page: int  # 1-based page number that was requested
    page_size: int  # maximum number of items per page
    pages: int  # total page count; reported as 1 when nothing matches

67 

68 

69# ===== CRUD Endpoints ===== 

70 

71 

@router.get("/brokerage-domains", response_model=PaginatedDomains)
async def list_brokerage_domains(
    user: RequiredUser,
    db: Session = Depends(get_db),
    brokerage_id: int | None = Query(None, description="Filter by brokerage ID"),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
):
    """
    List brokerage domains with pagination.

    - Admins see all domains and may filter by ``brokerage_id``.
    - Non-admins only ever see domains of their own brokerage. Requesting a
      different ``brokerage_id`` is rejected with 403, and a non-admin without
      a brokerage gets an empty page.

    Results are ordered primary-first, then alphabetically by domain.

    Raises:
        HTTPException 403: a non-admin asked for another brokerage's domains.
    """
    if user.role != "admin":
        own_brokerage = user.brokerage_id
        if own_brokerage is None:
            # A non-admin with no brokerage has no domains they may see.
            return PaginatedDomains(
                items=[], total=0, page=page, page_size=page_size, pages=1
            )
        if brokerage_id is not None and brokerage_id != own_brokerage:
            raise HTTPException(status_code=403, detail="Access denied")
        # Always scope non-admins to their own brokerage, whatever was requested.
        brokerage_id = own_brokerage

    filters = []
    if brokerage_id is not None:
        filters.append(BrokerageDomain.brokerage_id == brokerage_id)

    count_stmt = select(func.count()).select_from(BrokerageDomain)
    page_stmt = (
        select(BrokerageDomain)
        .order_by(BrokerageDomain.is_primary.desc(), BrokerageDomain.domain.asc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    if filters:
        count_stmt = count_stmt.where(*filters)
        page_stmt = page_stmt.where(*filters)

    total = db.scalar(count_stmt) or 0
    domains = db.scalars(page_stmt).all()
    # Ceiling division; report a single (empty) page when nothing matches.
    total_pages = max(1, (total + page_size - 1) // page_size)

    return PaginatedDomains(
        items=[BrokerageDomainResponse.model_validate(d) for d in domains],
        total=total,
        page=page,
        page_size=page_size,
        pages=total_pages,
    )

124 

125 

@router.get("/brokerage-domains/{domain_id}", response_model=BrokerageDomainResponse)
async def get_brokerage_domain(
    domain_id: int,
    user: RequiredUser,
    db: Session = Depends(get_db),
):
    """
    Fetch one brokerage domain by primary key.

    Admins may read any domain. Everyone else may only read domains that
    belong to their own brokerage.

    Raises:
        HTTPException 404: no domain with ``domain_id`` exists.
        HTTPException 403: a non-admin asked for another brokerage's domain.
    """
    record = db.get(BrokerageDomain, domain_id)
    if not record:
        raise HTTPException(status_code=404, detail="Domain not found")

    is_admin = user.role == "admin"
    owns_domain = user.brokerage_id == record.brokerage_id
    if not (is_admin or owns_domain):
        raise HTTPException(status_code=403, detail="Access denied")

    return BrokerageDomainResponse.model_validate(record)

147 

148 

@router.post("/brokerage-domains", response_model=BrokerageDomainResponse)
async def create_brokerage_domain(
    data: BrokerageDomainCreate,
    user: AdminUser,  # Only admins can create domains (security implications)
    db: Session = Depends(get_db),
):
    """
    Create a new brokerage domain mapping.

    Requires admin role (domain mappings have security implications).

    The domain is normalized once (lowercased, surrounding whitespace
    stripped). The duplicate check and the stored row both use that same
    value.

    If is_primary=True, will unset is_primary on other domains for that brokerage.

    Raises:
        HTTPException 400: the domain is empty after normalization, or is
            already registered. This includes losing a race with a concurrent
            create for the same name.
    """
    normalized = data.domain.lower().strip()
    if not normalized:
        raise HTTPException(status_code=400, detail="Domain must not be empty")

    # Check if domain already exists (using the same normalized value we store)
    existing = db.scalar(
        select(BrokerageDomain).where(BrokerageDomain.domain == normalized)
    )
    if existing:
        raise HTTPException(
            status_code=400,
            detail=f"Domain '{data.domain}' is already registered",
        )

    now = datetime.now(timezone.utc)

    # If setting as primary, unset primary for other domains in this brokerage
    if data.is_primary:
        current_primaries = db.scalars(
            select(BrokerageDomain).where(
                BrokerageDomain.brokerage_id == data.brokerage_id,
                BrokerageDomain.is_primary.is_(True),
            )
        ).all()
        for other in current_primaries:
            other.is_primary = False
            other.updated_at = now

    domain = BrokerageDomain(
        brokerage_id=data.brokerage_id,
        domain=normalized,
        is_primary=data.is_primary,
        notes=data.notes,
        created_at=now,
        updated_at=now,
    )

    try:
        db.add(domain)
        db.commit()
        db.refresh(domain)
    except IntegrityError as exc:
        db.rollback()
        # A concurrent create may have claimed the name after our pre-check.
        # Only report "already registered" if that is really what happened.
        # Any other integrity failure is re-raised unchanged instead of being
        # guessed from the error text.
        taken = db.scalar(
            select(BrokerageDomain.id).where(BrokerageDomain.domain == normalized)
        )
        if taken is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Domain '{data.domain}' is already registered",
            ) from exc
        raise
    except Exception:
        db.rollback()
        raise

    # Invalidate cache for this new domain
    await invalidate_site_config(domain.domain)

    return BrokerageDomainResponse.model_validate(domain)

213 

214 

@router.put("/brokerage-domains/{domain_id}", response_model=BrokerageDomainResponse)
async def update_brokerage_domain(
    domain_id: int,
    data: BrokerageDomainUpdate,
    user: AdminUser,  # Only admins can update domains
    db: Session = Depends(get_db),
):
    """
    Update an existing brokerage domain.

    Requires admin role.

    Only fields present in the request body are applied.

    - An explicit ``null`` for ``domain`` or ``is_primary`` means "leave
      unchanged". An explicit ``null`` for ``notes`` clears the notes.
    - Renaming the domain resets verification. The cached site config is
      invalidated for both the old and the new name.
    - Resubmitting the current name is a no-op for that field.

    If setting is_primary=True, will unset is_primary on other domains for that brokerage.

    Raises:
        HTTPException 404: no domain with ``domain_id`` exists.
        HTTPException 400: the new name is empty or already registered.
    """
    domain = db.get(BrokerageDomain, domain_id)
    if not domain:
        raise HTTPException(status_code=404, detail="Domain not found")

    old_name = domain.domain
    update_data = data.model_dump(exclude_unset=True)

    # `domain` must be a string, and `is_primary` must be a bool (the response
    # model requires it). Treat explicit nulls as "no change" instead of
    # crashing on `.lower()` or persisting None.
    for key in ("domain", "is_primary"):
        if key in update_data and update_data[key] is None:
            del update_data[key]

    if "domain" in update_data:
        new_name = update_data["domain"].lower().strip()
        if not new_name:
            raise HTTPException(status_code=400, detail="Domain must not be empty")
        if new_name == old_name:
            # Same name resubmitted: keep the existing verification status.
            del update_data["domain"]
        else:
            existing = db.scalar(
                select(BrokerageDomain).where(
                    BrokerageDomain.domain == new_name,
                    BrokerageDomain.id != domain_id,
                )
            )
            if existing:
                raise HTTPException(
                    status_code=400,
                    detail=f"Domain '{new_name}' is already registered",
                )
            update_data["domain"] = new_name
            # Changing domain resets verification
            domain.is_verified = False
            domain.verified_at = None

    now = datetime.now(timezone.utc)

    # If setting as primary, unset primary for other domains in this brokerage
    if update_data.get("is_primary"):
        other_primaries = db.scalars(
            select(BrokerageDomain).where(
                BrokerageDomain.brokerage_id == domain.brokerage_id,
                BrokerageDomain.is_primary.is_(True),
                BrokerageDomain.id != domain_id,
            )
        ).all()
        for other in other_primaries:
            other.is_primary = False
            other.updated_at = now

    for field, value in update_data.items():
        setattr(domain, field, value)
    domain.updated_at = now

    try:
        db.commit()
    except IntegrityError as exc:
        db.rollback()
        # A concurrent request may have claimed the new name after our check.
        new_name = update_data.get("domain")
        if new_name is not None and db.scalar(
            select(BrokerageDomain.id).where(
                BrokerageDomain.domain == new_name,
                BrokerageDomain.id != domain_id,
            )
        ) is not None:
            raise HTTPException(
                status_code=400,
                detail=f"Domain '{new_name}' is already registered",
            ) from exc
        raise
    db.refresh(domain)

    # Invalidate cache for this domain, and for the old name after a rename so
    # it stops serving the stale mapping.
    await invalidate_site_config(domain.domain)
    if domain.domain != old_name:
        await invalidate_site_config(old_name)

    return BrokerageDomainResponse.model_validate(domain)

280 

281 

@router.delete("/brokerage-domains/{domain_id}")
async def delete_brokerage_domain(
    domain_id: int,
    user: AdminUser,  # Only admins can delete domains
    db: Session = Depends(get_db),
):
    """
    Remove a brokerage domain mapping.

    Admin only. After the row is deleted, the cached site config for its
    hostname is invalidated.

    Raises:
        HTTPException 404: no domain with ``domain_id`` exists.
    """
    record = db.get(BrokerageDomain, domain_id)
    if not record:
        raise HTTPException(status_code=404, detail="Domain not found")

    # Read the hostname up front so it is still available after the row is
    # deleted and committed.
    removed_name = record.domain

    db.delete(record)
    db.commit()

    await invalidate_site_config(removed_name)

    return {"message": "Domain deleted successfully"}

307 

308 

class DomainVerificationResponse(BaseModel):
    """Domain verification response with DNS details."""

    domain: BrokerageDomainResponse  # domain state after the verification attempt
    # Dict produced by check_domain_cname: keys valid, cname_found, expected, error.
    dns_check: dict

314 

315 

def check_domain_cname(
    domain_name: str,
    *,
    nameservers: tuple[str, ...] = ("8.8.8.8", "1.1.1.1"),
) -> dict:
    """
    Verify that a domain has a CNAME record pointing to the expected target.

    The expected target is ``settings.domain_cname_target``. The comparison
    ignores case and a trailing dot.

    Args:
        domain_name: Hostname to check.
        nameservers: Resolvers to query. Defaults to public resolvers so the
            result does not depend on the host's own resolver configuration.

    Returns a dict with:
    - valid: bool - whether the CNAME is correctly configured
    - cname_found: str | None - the CNAME target found
    - expected: str - what we're looking for
    - error: str | None - any error message

    DNS lookup failures are reported through ``error`` rather than raised.
    """
    expected_target = settings.domain_cname_target.rstrip(".")

    result = {
        "valid": False,
        "cname_found": None,
        "expected": expected_target,
        "error": None,
    }

    try:
        # configure=False: don't read /etc/resolv.conf. We supply our own
        # nameservers, and a missing resolv.conf (common in containers)
        # would otherwise make every check fail.
        resolver = dns.resolver.Resolver(configure=False)
        resolver.nameservers = list(nameservers)
        resolver.timeout = 5  # seconds to wait on a single nameserver
        resolver.lifetime = 10  # total seconds allowed per resolve() call

        answers = resolver.resolve(domain_name, "CNAME")

        for rdata in answers:
            cname_target = str(rdata.target).rstrip(".")
            result["cname_found"] = cname_target

            # Check if CNAME matches expected target (case-insensitive)
            if cname_target.lower() == expected_target.lower():
                result["valid"] = True
                break

    except dns.resolver.NXDOMAIN:
        result["error"] = f"Domain '{domain_name}' does not exist in DNS"
    except dns.resolver.NoAnswer:
        # No CNAME record. The name may still resolve via an A record, which
        # lets us give a more specific hint.
        try:
            resolver.resolve(domain_name, "A")
            result["error"] = f"Domain has A record but no CNAME. Please add: {domain_name} CNAME {expected_target}"
        except Exception:
            # Best-effort probe only. `Exception` (not a bare except) so
            # KeyboardInterrupt/SystemExit still propagate.
            result["error"] = f"No CNAME record found. Please add: {domain_name} CNAME {expected_target}"
    except dns.exception.Timeout:
        result["error"] = "DNS query timed out. Please try again."
    except Exception as e:
        result["error"] = f"DNS lookup failed: {str(e)}"

    return result

369 

370 

@router.post("/brokerage-domains/{domain_id}/verify", response_model=DomainVerificationResponse)
async def verify_domain(
    domain_id: int,
    user: AdminUser,
    db: Session = Depends(get_db),
):
    """
    Verify a domain's DNS configuration.

    Checks that the domain has a CNAME record pointing to the expected
    target hostname (configured via DOMAIN_CNAME_TARGET env var).

    If verification succeeds, marks the domain as verified and invalidates
    the cached site config for it. A failed check leaves the current
    verification status untouched. The DNS details are returned either way.

    Requires admin role.

    Raises:
        HTTPException 404: no domain with ``domain_id`` exists.
    """
    domain = db.get(BrokerageDomain, domain_id)
    if not domain:
        raise HTTPException(status_code=404, detail="Domain not found")

    # check_domain_cname makes blocking DNS queries (each bounded by the
    # resolver's lifetime). Run it in a worker thread so the event loop keeps
    # serving other requests in the meantime.
    dns_result = await asyncio.to_thread(check_domain_cname, domain.domain)

    if dns_result["valid"]:
        now = datetime.now(timezone.utc)
        domain.is_verified = True
        domain.verified_at = now
        domain.updated_at = now
        db.commit()
        db.refresh(domain)
        # The public lookup only matches verified domains. Drop any cached
        # site config for this name, as create/update/delete do.
        await invalidate_site_config(domain.domain)

    return DomainVerificationResponse(
        domain=BrokerageDomainResponse.model_validate(domain),
        dns_check=dns_result,
    )

405 

406 

407# ===== DNS Info Endpoint ===== 

408 

409 

class DnsProviderInfo(BaseModel):
    """DNS provider information, as reported by ``get_dns_info``."""
    name: str  # provider display name
    icon: str  # presumably a UI icon identifier; format defined by idx_api.dns_providers
    color: str  # presumably a UI color value; format defined by idx_api.dns_providers

415 

416 

class MailProviderInfo(BaseModel):
    """Mail provider information, as reported by ``get_dns_info``."""
    name: str  # provider display name
    icon: str  # presumably a UI icon identifier; format defined by idx_api.dns_providers
    color: str  # presumably a UI color value; format defined by idx_api.dns_providers

422 

423 

class MxRecordInfo(BaseModel):
    """MX record with priority."""
    priority: int  # MX preference value
    host: str  # mail exchanger hostname

428 

429 

class ServiceVerificationInfo(BaseModel):
    """Service verification detected in TXT records."""
    service: str  # name of the service whose verification record was found
    icon: str  # presumably a UI icon identifier; format defined by idx_api.dns_providers
    color: str  # presumably a UI color value; format defined by idx_api.dns_providers

435 

436 

class DnsInfoResponse(BaseModel):
    """DNS information response for a domain.

    Populated from the dict returned by ``get_dns_info``; see the dns-info
    endpoint for the field mapping.
    """
    # DNS Provider
    nameservers: list[str]
    provider: DnsProviderInfo | None  # None when no provider was detected
    # Mail
    mx_records: list[MxRecordInfo]
    mail_provider: MailProviderInfo | None  # None when no mail provider was detected
    # TXT Records
    spf: str | None  # presumably the raw SPF record text, if any
    dmarc: str | None  # presumably the raw DMARC record text, if any
    dkim_selectors: list[str]
    verifications: list[ServiceVerificationInfo]
    # WHOIS
    registrar: str | None
    creation_date: str | None  # string as produced by get_dns_info; format not fixed here
    expiration_date: str | None  # string as produced by get_dns_info; format not fixed here
    # Security
    dnssec: bool | None  # None presumably means "unknown"; confirm in idx_api.dns_providers
    # Meta
    error: str | None  # lookup error message, if any
    cached: bool  # True when served from get_dns_info's cache

459 

460 

@router.get("/brokerage-domains/{domain_id}/dns-info", response_model=DnsInfoResponse)
async def get_domain_dns_info(
    domain_id: int,
    user: RequiredUser,
    db: Session = Depends(get_db),
    refresh: bool = Query(False, description="Bypass cache and fetch fresh data"),
):
    """
    Report DNS, mail, TXT and WHOIS details for one brokerage domain.

    Includes nameservers, detected DNS and mail providers, MX records, SPF,
    DMARC, DKIM selectors, detected service verifications, registrar and
    dates, and DNSSEC status. Data comes from ``get_dns_info``, which may
    serve cached results; ``refresh=true`` asks it to bypass the cache.

    Non-admins may only query domains of their own brokerage.

    Raises:
        HTTPException 404: no domain with ``domain_id`` exists.
        HTTPException 403: a non-admin asked about another brokerage's domain.
    """
    record = db.get(BrokerageDomain, domain_id)
    if not record:
        raise HTTPException(status_code=404, detail="Domain not found")

    is_admin = user.role == "admin"
    if not is_admin and user.brokerage_id != record.brokerage_id:
        raise HTTPException(status_code=403, detail="Access denied")

    # NOTE(review): get_dns_info is called synchronously inside this async
    # endpoint. If it performs network/WHOIS lookups it blocks the event loop
    # while they run. Consider offloading it (e.g. asyncio.to_thread) once its
    # cache is confirmed thread-safe.
    info = get_dns_info(record.domain, refresh=refresh)

    provider_raw = info.get("provider")
    provider = (
        DnsProviderInfo(
            name=provider_raw["name"],
            icon=provider_raw["icon"],
            color=provider_raw["color"],
        )
        if provider_raw
        else None
    )

    mail_raw = info.get("mail_provider")
    mail_provider = (
        MailProviderInfo(
            name=mail_raw["name"],
            icon=mail_raw["icon"],
            color=mail_raw["color"],
        )
        if mail_raw
        else None
    )

    mx_records = [
        MxRecordInfo(priority=record_mx["priority"], host=record_mx["host"])
        for record_mx in info.get("mx_records", [])
    ]

    verifications = [
        ServiceVerificationInfo(
            service=found["service"],
            icon=found["icon"],
            color=found["color"],
        )
        for found in info.get("verifications", [])
    ]

    return DnsInfoResponse(
        nameservers=info.get("nameservers", []),
        provider=provider,
        mx_records=mx_records,
        mail_provider=mail_provider,
        spf=info.get("spf"),
        dmarc=info.get("dmarc"),
        dkim_selectors=info.get("dkim_selectors", []),
        verifications=verifications,
        registrar=info.get("registrar"),
        creation_date=info.get("creation_date"),
        expiration_date=info.get("expiration_date"),
        dnssec=info.get("dnssec"),
        error=info.get("error"),
        cached=info.get("cached", False),
    )

545 

546 

547# ===== Public Lookup Endpoint ===== 

548 

549 

@router.get("/public/domain-lookup")
async def lookup_domain(
    domain: str = Query(..., description="Domain to look up"),
    db: Session = Depends(get_db),
):
    """
    Resolve a public hostname to the brokerage that owns it.

    Unauthenticated; the frontend uses it to pick which brokerage's
    configuration to load for site.json.

    The host is lowercased and trimmed, and one leading ``www.`` is dropped.
    The bare name is then tried first, followed by its ``www.`` form. Only
    verified domains can match.

    Raises:
        HTTPException 404: neither form matches a verified domain.
    """
    bare_host = domain.lower().strip().removeprefix("www.")

    match = None
    for candidate in (bare_host, f"www.{bare_host}"):
        match = db.scalar(
            select(BrokerageDomain).where(
                BrokerageDomain.domain == candidate,
                BrokerageDomain.is_verified == True,  # noqa: E712 - SQL expression
            )
        )
        if match:
            break

    if not match:
        raise HTTPException(status_code=404, detail="Domain not found or not verified")

    return {
        "domain": match.domain,
        "brokerage_id": match.brokerage_id,
        "is_primary": match.is_primary,
    }