from __future__ import annotations

import base64
import http.client
import json
from dataclasses import dataclass
from typing import Protocol
from urllib.parse import urlparse

# DNS record synchronisation backends: an AdGuard Home provider (DNS rewrites
# via its HTTP control API) and an RFC 2136 dynamic-update provider (via
# dnspython), plus helpers that turn an entry name into a DNS-safe FQDN.


class DnsSyncError(RuntimeError):
    """Raised when a DNS name cannot be derived or a DNS record cannot be changed."""


class AdguardHttpError(DnsSyncError):
    """A non-2xx HTTP response from the AdGuard Home API.

    Subclasses DnsSyncError so existing ``except DnsSyncError`` handlers keep
    working; ``status`` lets the client recognise 401/403 and re-authenticate.
    """

    def __init__(self, message: str, status: int) -> None:
        super().__init__(message)
        self.status = status


class DnsProvider(Protocol):
    """Structural interface implemented by every DNS backend in this module."""

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Create or replace the record for ``fqdn`` so it resolves to ``ip``."""
        raise NotImplementedError

    def delete_a_record(self, fqdn: str) -> None:
        """Remove every record the backend holds for ``fqdn``."""
        raise NotImplementedError


def to_fqdn(entry_name: str, base_domain: str) -> str:
    """Build ``<label>.<base_domain>`` from a free-form entry name.

    The label is derived by ``_sanitize_label``; the domain is stripped,
    lowercased, and has leading/trailing dots removed.

    Raises:
        DnsSyncError: if the label cannot be derived, ``base_domain`` is empty,
            or the resulting name exceeds the 253-character DNS name limit.
    """
    label = _sanitize_label(entry_name)
    domain = base_domain.strip().lower().strip(".")
    if not domain:
        raise DnsSyncError("DNS_BASE_DOMAIN is required when DNS is enabled")
    fqdn = f"{label}.{domain}"
    # RFC 1035 caps a full name at 255 octets on the wire, i.e. 253 text chars.
    if len(fqdn) > 253:
        raise DnsSyncError(f"DNS name derived from entry name is too long (max 253): {fqdn!r}")
    return fqdn


def _sanitize_label(value: str) -> str:
    """Turn ``value`` into a single lowercase DNS label (``[a-z0-9-]``, max 63 chars).

    Spaces, underscores and dashes become one dash (runs collapse); every other
    character outside ``a-z``/``0-9`` is dropped; leading/trailing dashes are
    stripped.

    Raises:
        DnsSyncError: if ``value`` is blank, yields an empty label, or the
            label is longer than 63 characters.
    """
    source = value.strip().lower()
    if not source:
        raise DnsSyncError("Entry name is required to create DNS record")
    cleaned: list[str] = []
    prev_dash = False
    for ch in source:
        if "a" <= ch <= "z" or "0" <= ch <= "9":
            cleaned.append(ch)
            prev_dash = False
            continue
        # Separators map to a single dash; anything else is silently dropped.
        if ch in {" ", "_", "-"} and not prev_dash:
            cleaned.append("-")
            prev_dash = True
    label = "".join(cleaned).strip("-")
    if not label:
        raise DnsSyncError(f"Entry name cannot produce DNS-safe label: {value!r}")
    if len(label) > 63:
        raise DnsSyncError("DNS label derived from entry name is too long (max 63)")
    return label


@dataclass(frozen=True)
class AdguardConfig:
    """Connection settings for :class:`AdguardDnsProvider`."""

    url: str  # base URL of the AdGuard Home web UI/API, http or https
    username: str
    password: str
    timeout_seconds: float  # socket timeout for every HTTP request


class AdguardDnsProvider:
    """Manage records as AdGuard Home DNS rewrites over its ``/control`` HTTP API.

    Authenticates via ``/control/login`` and sends the returned ``agh_session``
    cookie on later requests; if a request is rejected as unauthorized, the
    session is refreshed once and the request retried.
    """

    def __init__(self, config: AdguardConfig):
        """Validate ``config.url`` and store connection settings (no I/O here).

        Raises:
            ValueError: if the URL scheme is not http/https or it has no host.
        """
        parsed = urlparse(config.url)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError("ADGUARD_URL must use http or https")
        if not parsed.netloc:
            raise ValueError("ADGUARD_URL must include host")
        self._https = parsed.scheme == "https"
        self._host = parsed.hostname or "localhost"
        # Always pass an explicit port: with port=None http.client re-parses the
        # host for ":port", which mangles bare IPv6 literals such as "::1".
        self._port = parsed.port or (443 if self._https else 80)
        self._base_path = parsed.path.rstrip("/")
        self._username = config.username
        self._password = config.password
        self._timeout = config.timeout_seconds
        self._session_cookie: str | None = None

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Ensure ``fqdn`` has exactly one rewrite, pointing at ``ip``.

        Every existing rewrite for ``fqdn`` with a different answer is deleted
        (regardless of list order); the ``ip`` rewrite is added only if absent.
        """
        del ttl  # AdGuard rewrite records do not expose TTL controls.
        # dict.fromkeys de-duplicates while keeping the server's order.
        answers = list(
            dict.fromkeys(item.get("answer", "") for item in self._list_rewrites() if item.get("domain") == fqdn)
        )
        for answer in answers:
            if answer != ip:
                self._request("POST", "/control/rewrite/delete", {"domain": fqdn, "answer": answer})
        if ip not in answers:
            self._request("POST", "/control/rewrite/add", {"domain": fqdn, "answer": ip})

    def delete_a_record(self, fqdn: str) -> None:
        """Delete every rewrite whose domain equals ``fqdn``."""
        for item in self._list_rewrites():
            if item.get("domain") == fqdn:
                self._request("POST", "/control/rewrite/delete", {"domain": fqdn, "answer": item.get("answer", "")})

    def _list_rewrites(self) -> list[dict]:
        """Return the dict entries of ``/control/rewrite/list`` (non-dict items are skipped).

        Raises:
            DnsSyncError: if the response body is not a JSON list.
        """
        payload = self._request("GET", "/control/rewrite/list", None)
        if not isinstance(payload, list):
            raise DnsSyncError("AdGuard returned unexpected rewrite list format")
        return [item for item in payload if isinstance(item, dict)]

    def _request(self, method: str, path: str, payload: dict | None) -> object:
        """Send an authenticated request, logging in first if there is no session yet."""
        if self._session_cookie is None:
            self._login()
        return self._request_with_session(method, path, payload, retry_on_auth=True)

    def _login(self) -> None:
        """POST credentials to ``/control/login`` and keep the ``agh_session`` cookie.

        Raises:
            DnsSyncError: if the request fails or no ``agh_session`` cookie is returned.
        """
        credentials = {"name": self._username, "password": self._password}
        _, headers = self._raw_request("POST", "/control/login", credentials, include_auth=False)
        if headers is None:
            raise DnsSyncError("AdGuard login failed: missing response headers")
        cookie = headers.get("set-cookie", "")
        session = ""
        for piece in cookie.split(";"):
            piece = piece.strip()
            if piece.startswith("agh_session="):
                session = piece
                break
        if not session:
            raise DnsSyncError("AdGuard login failed: no agh_session cookie")
        self._session_cookie = session

    def _request_with_session(self, method: str, path: str, payload: dict | None, retry_on_auth: bool) -> object:
        """Send a request with the session cookie, re-logging in once if unauthorized.

        "Unauthorized" is either an HTTP 401/403 response or a JSON body of
        ``{"message": "unauthorized"}``. Other HTTP errors propagate unchanged.

        Raises:
            DnsSyncError: if the request is still unauthorized after one re-login.
        """
        auth_error = None
        try:
            body, _ = self._raw_request(method, path, payload, include_auth=True)
        except AdguardHttpError as exc:
            if exc.status not in (401, 403):
                raise
            auth_error = exc
        else:
            if not (isinstance(body, dict) and body.get("message") == "unauthorized"):
                return body
        if retry_on_auth:
            # The session likely expired server-side: drop it and log in again.
            self._session_cookie = None
            self._login()
            return self._request_with_session(method, path, payload, retry_on_auth=False)
        raise DnsSyncError("AdGuard request unauthorized") from auth_error

    def _raw_request(
        self, method: str, path: str, payload: dict | None, include_auth: bool
    ) -> tuple[object, dict[str, str] | None]:
        """Perform one HTTP request and return ``(parsed_body, lowercased_headers)``.

        The body is parsed as JSON when possible, returned as text otherwise, and
        an empty body becomes ``{}``. Repeated headers keep their last value.

        Raises:
            AdguardHttpError: on a non-2xx status (``status`` holds the code).
            DnsSyncError: on socket/TLS/protocol failures.
        """
        conn_cls = http.client.HTTPSConnection if self._https else http.client.HTTPConnection
        conn = conn_cls(self._host, self._port, timeout=self._timeout)
        headers = {"Content-Type": "application/json"}
        if include_auth and self._session_cookie:
            headers["Cookie"] = self._session_cookie
        # Encode explicitly: http.client would encode a str body as latin-1.
        raw = None if payload is None else json.dumps(payload).encode("utf-8")
        try:
            conn.request(method, f"{self._base_path}{path}", body=raw, headers=headers)
            response = conn.getresponse()
            body_text = response.read().decode("utf-8", errors="replace")
            response_headers = {k.lower(): v for k, v in response.getheaders()}
        except (OSError, http.client.HTTPException) as exc:
            # HTTPException (IncompleteRead, BadStatusLine, ...) is not an OSError.
            raise DnsSyncError(f"AdGuard request failed for {path}: {exc}") from exc
        finally:
            conn.close()
        if not 200 <= response.status < 300:
            raise AdguardHttpError(
                f"AdGuard request failed for {path}: HTTP {response.status} {response.reason}; body={body_text[:400]}",
                response.status,
            )
        if not body_text.strip():
            return {}, response_headers
        try:
            return json.loads(body_text), response_headers
        except json.JSONDecodeError:
            return body_text, response_headers


@dataclass(frozen=True)
class Rfc2136Config:
    """Connection and TSIG settings for :class:`Rfc2136DnsProvider`."""

    server: str  # DNS server address that accepts UPDATE messages
    zone: str  # zone the updates are addressed to
    port: int
    timeout_seconds: float
    tsig_key_name: str  # empty together with tsig_secret => unsigned updates
    tsig_secret: str  # base64-encoded TSIG secret
    tsig_algorithm: str  # empty => "hmac-sha256"


class Rfc2136DnsProvider:
    """Manage A records with RFC 2136 dynamic UPDATE messages sent over TCP (dnspython)."""

    # TSIG algorithm names accepted by _keyring_or_none (lowercase).
    _TSIG_ALGORITHMS = frozenset(
        {
            "hmac-md5.sig-alg.reg.int",
            "hmac-sha1",
            "hmac-sha224",
            "hmac-sha256",
            "hmac-sha384",
            "hmac-sha512",
        }
    )

    def __init__(self, config: Rfc2136Config):
        """Validate and store settings (no I/O here).

        Raises:
            ValueError: if ``server`` or ``zone`` is blank.
        """
        if not config.server.strip():
            raise ValueError("RFC2136_SERVER is required")
        if not config.zone.strip():
            raise ValueError("RFC2136_ZONE is required")
        self._server = config.server.strip()
        self._zone = config.zone.strip().rstrip(".")
        self._port = config.port
        self._timeout = config.timeout_seconds
        self._key_name = config.tsig_key_name.strip()
        self._secret = config.tsig_secret.strip()
        # Normalise case so e.g. "HMAC-SHA256" passes the supported-algorithm check.
        self._algorithm = config.tsig_algorithm.strip().lower() or "hmac-sha256"

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Replace all A records of ``fqdn`` with a single ``ip`` record with TTL ``ttl``.

        The delete and the add travel in the same UPDATE message.

        Raises:
            DnsSyncError: on transport/library errors or a non-NOERROR response.
        """
        target = self._absolute_name(fqdn)

        def build(req) -> None:
            req.delete(target, "A")
            req.add(target, int(ttl), "A", ip)

        self._send_update(build, f"RFC2136 upsert failed for {fqdn} -> {ip}", f"RFC2136 upsert failed for {fqdn}")

    def delete_a_record(self, fqdn: str) -> None:
        """Delete all A records of ``fqdn``.

        Raises:
            DnsSyncError: on transport/library errors or a non-NOERROR response.
        """
        target = self._absolute_name(fqdn)

        def build(req) -> None:
            req.delete(target, "A")

        self._send_update(build, f"RFC2136 delete failed for {fqdn}", f"RFC2136 delete failed for {fqdn}")

    def _send_update(self, build, error_prefix: str, rcode_prefix: str) -> None:
        """Create an UPDATE for the zone, let ``build`` fill it, send it over TCP, check the rcode.

        Any exception raised while building or sending is wrapped as
        ``DnsSyncError("<error_prefix>: <exc>")``; a non-NOERROR reply becomes
        ``DnsSyncError("<rcode_prefix>: <rcode text>")``.
        """
        rcode, tsigkeyring, update, query = self._dns_modules()
        zone_text = self._zone_with_dot()
        keyring = self._keyring_or_none(tsigkeyring)
        try:
            req = update.Update(
                zone_text, keyring=keyring, keyname=self._key_name or None, keyalgorithm=self._algorithm
            )
            build(req)
            response = query.tcp(req, self._server, port=self._port, timeout=self._timeout)
        except Exception as exc:  # noqa: BLE001 - library/network errors of many kinds
            raise DnsSyncError(f"{error_prefix}: {exc}") from exc
        if response.rcode() != rcode.NOERROR:
            raise DnsSyncError(f"{rcode_prefix}: {rcode.to_text(response.rcode())}")

    def _dns_modules(self):
        """Import dnspython lazily so it is only required when RFC 2136 mode is used.

        Returns:
            ``(dns.rcode, dns.tsigkeyring, dns.update, dns.query)``.

        Raises:
            DnsSyncError: if dnspython is not installed.
        """
        try:
            import dns.query as query
            import dns.rcode as rcode
            import dns.tsigkeyring as tsigkeyring
            import dns.update as update
        except ImportError as exc:
            raise DnsSyncError("dnspython is required for RFC2136 mode") from exc
        return rcode, tsigkeyring, update, query

    def _keyring_or_none(self, tsigkeyring):
        """Return a TSIG keyring built by ``tsigkeyring.from_text``, or None when TSIG is not configured.

        Raises:
            DnsSyncError: if only one of key name/secret is set, the secret is
                not valid base64, or the algorithm is unsupported.
        """
        if not self._key_name and not self._secret:
            return None
        if not self._key_name or not self._secret:
            raise DnsSyncError("RFC2136 TSIG requires both key name and secret")
        key_name = self._key_name if self._key_name.endswith(".") else f"{self._key_name}."
        try:
            base64.b64decode(self._secret, validate=True)
        except Exception as exc:  # noqa: BLE001
            raise DnsSyncError("RFC2136_TSIG_SECRET must be valid base64") from exc
        if self._algorithm not in self._TSIG_ALGORITHMS:
            raise DnsSyncError(f"Unsupported TSIG algorithm: {self._algorithm}")
        return tsigkeyring.from_text({key_name: self._secret})

    def _zone_with_dot(self) -> str:
        """Return the zone name with a trailing dot (absolute form)."""
        return self._zone if self._zone.endswith(".") else f"{self._zone}."

    def _absolute_name(self, fqdn: str) -> str:
        """Return ``fqdn`` with a trailing dot so dnspython treats it as absolute."""
        return fqdn if fqdn.endswith(".") else f"{fqdn}."


def build_dns_provider(
    provider_name: str,
    *,
    adguard_url: str,
    adguard_username: str,
    adguard_password: str,
    rfc2136_server: str,
    rfc2136_zone: str,
    rfc2136_port: int,
    rfc2136_tsig_key_name: str,
    rfc2136_tsig_secret: str,
    rfc2136_tsig_algorithm: str,
    timeout_seconds: float,
) -> DnsProvider | None:
    """Construct the provider named by ``provider_name`` (case-insensitive).

    Returns None for an empty name or ``"none"``; otherwise an AdGuard or
    RFC 2136 provider built from the matching keyword arguments.

    Raises:
        DnsSyncError: for an unknown provider or missing AdGuard URL/credentials.
        ValueError: propagated from the provider constructors for an invalid
            AdGuard URL or blank RFC 2136 server/zone.
    """
    mode = provider_name.strip().lower()
    if not mode or mode == "none":
        return None
    if mode == "adguard":
        if not adguard_url.strip():
            raise DnsSyncError("ADGUARD_URL is required for DNS_PROVIDER=adguard")
        if not adguard_username.strip() or not adguard_password.strip():
            raise DnsSyncError("ADGUARD_USERNAME and ADGUARD_PASSWORD are required for DNS_PROVIDER=adguard")
        return AdguardDnsProvider(
            AdguardConfig(
                url=adguard_url,
                username=adguard_username,
                password=adguard_password,
                timeout_seconds=timeout_seconds,
            )
        )
    if mode == "rfc2136":
        return Rfc2136DnsProvider(
            Rfc2136Config(
                server=rfc2136_server,
                zone=rfc2136_zone,
                port=rfc2136_port,
                timeout_seconds=timeout_seconds,
                tsig_key_name=rfc2136_tsig_key_name,
                tsig_secret=rfc2136_tsig_secret,
                tsig_algorithm=rfc2136_tsig_algorithm,
            )
        )
    raise DnsSyncError(f"Unsupported DNS_PROVIDER: {provider_name}")