# zima-apps/Apps/docker-ip-addr-manager/backend/app/dns_sync.py
# DNS record sync providers: AdGuard Home rewrites and RFC 2136 dynamic updates.
from __future__ import annotations

import base64
import http.client
import json
from dataclasses import dataclass, field
from typing import Protocol
from urllib.parse import urlparse
class DnsSyncError(RuntimeError):
    """Raised when DNS settings are invalid or a DNS record cannot be synchronised."""
class DnsProvider(Protocol):
    """Structural interface implemented by the DNS backends in this module.

    A provider creates/replaces or removes the A record for a fully-qualified name.
    """

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Create or replace the A record ``fqdn`` -> ``ip`` (``ttl`` where the backend supports it)."""
        raise NotImplementedError

    def delete_a_record(self, fqdn: str) -> None:
        """Remove the A record(s) published for ``fqdn``."""
        raise NotImplementedError
def to_fqdn(entry_name: str, base_domain: str) -> str:
    """Return ``<label>.<domain>`` for an entry name under the configured base domain.

    The entry name is converted into a DNS label with ``_sanitize_label``. The base
    domain is trimmed of whitespace, lowercased and stripped of leading/trailing dots.

    Raises:
        DnsSyncError: if no label can be derived from ``entry_name`` (checked first)
            or if ``base_domain`` is empty after normalisation.
    """
    host_label = _sanitize_label(entry_name)
    zone = base_domain.strip().lower().strip(".")
    if zone:
        return f"{host_label}.{zone}"
    raise DnsSyncError("DNS_BASE_DOMAIN is required when DNS is enabled")
def _sanitize_label(value: str) -> str:
source = value.strip().lower()
if not source:
raise DnsSyncError("Entry name is required to create DNS record")
cleaned: list[str] = []
prev_dash = False
for ch in source:
if "a" <= ch <= "z" or "0" <= ch <= "9":
cleaned.append(ch)
prev_dash = False
continue
if ch in {" ", "_", "-"} and not prev_dash:
cleaned.append("-")
prev_dash = True
label = "".join(cleaned).strip("-")
if not label:
raise DnsSyncError(f"Entry name cannot produce DNS-safe label: {value!r}")
if len(label) > 63:
raise DnsSyncError("DNS label derived from entry name is too long (max 63)")
return label
@dataclass(frozen=True)
class AdguardConfig:
    """Connection settings for an AdGuard Home instance.

    ``password`` is excluded from the generated ``repr`` so a config object can be
    logged or shown in tracebacks without leaking the credential.
    """

    url: str  # base URL of the AdGuard Home API; must use http or https
    username: str
    password: str = field(repr=False)
    timeout_seconds: float  # socket timeout applied to each HTTP request
class AdguardDnsProvider:
    """DNS provider that publishes names as AdGuard Home DNS rewrites.

    Talks to the AdGuard Home HTTP API below the configured base URL and
    authenticates with the ``agh_session`` cookie returned by ``/control/login``.
    The session is created lazily on the first request. It is refreshed once when
    a request is rejected as unauthorized: either HTTP 401/403 or a JSON body
    ``{"message": "unauthorized"}``.
    """

    # HTTP statuses treated as "session missing or expired" (triggers one re-login).
    _AUTH_REJECTED_STATUSES = frozenset({401, 403})

    def __init__(self, config: AdguardConfig):
        """Validate ``config.url`` and remember the connection settings.

        Raises:
            ValueError: if the URL is not http(s), has no host, or has an invalid port.
        """
        parsed = urlparse(config.url)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError("ADGUARD_URL must use http or https")
        if not parsed.netloc:
            raise ValueError("ADGUARD_URL must include host")
        self._https = parsed.scheme == "https"
        self._host = parsed.hostname or "localhost"
        # Always pass an explicit port. With port=None, http.client re-parses the host
        # for ":port", which mangles bare IPv6 literals such as "::1".
        self._port = parsed.port or (443 if self._https else 80)
        self._base_path = parsed.path.rstrip("/")
        self._username = config.username
        self._password = config.password
        self._timeout = config.timeout_seconds
        self._session_cookie: str | None = None

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Make ``fqdn`` rewrite to ``ip`` only.

        Rewrites for ``fqdn`` pointing anywhere else are deleted, including ones listed
        after an already-correct entry. The ``fqdn`` -> ``ip`` rewrite is added only if
        it is not present yet.
        """
        del ttl  # AdGuard rewrite records do not expose TTL controls.
        answers = self._answers_for(fqdn)
        # One delete per distinct stale answer: the delete request is keyed only by
        # (domain, answer), so repeating it for duplicates is redundant.
        for stale in dict.fromkeys(answer for answer in answers if answer != ip):
            self._delete_rewrite(fqdn, stale)
        if ip not in answers:
            self._request("POST", "/control/rewrite/add", {"domain": fqdn, "answer": ip})

    def delete_a_record(self, fqdn: str) -> None:
        """Delete every rewrite whose domain is ``fqdn`` (no-op when there is none)."""
        for answer in dict.fromkeys(self._answers_for(fqdn)):
            self._delete_rewrite(fqdn, answer)

    def _answers_for(self, fqdn: str) -> list[str]:
        """Return the ``answer`` of each rewrite for ``fqdn``, in listing order ("" if absent)."""
        return [item.get("answer", "") for item in self._list_rewrites() if item.get("domain") == fqdn]

    def _delete_rewrite(self, fqdn: str, answer: str) -> None:
        """Ask AdGuard to delete the rewrite ``fqdn`` -> ``answer``."""
        self._request("POST", "/control/rewrite/delete", {"domain": fqdn, "answer": answer})

    def _list_rewrites(self) -> list[dict]:
        """Fetch all rewrites, ignoring any non-object entries.

        Raises:
            DnsSyncError: if the response is not a JSON list.
        """
        payload = self._request("GET", "/control/rewrite/list", None)
        if not isinstance(payload, list):
            raise DnsSyncError("AdGuard returned unexpected rewrite list format")
        return [item for item in payload if isinstance(item, dict)]

    def _request(self, method: str, path: str, payload: dict | None) -> object:
        """Send an authenticated API request, logging in first if there is no session yet."""
        if self._session_cookie is None:
            self._login()
        return self._request_with_session(method, path, payload, retry_on_auth=True)

    def _login(self) -> None:
        """Log in with the configured credentials and store the ``agh_session`` cookie.

        Raises:
            DnsSyncError: on transport errors, a non-2xx answer, or a missing session cookie.
        """
        path = "/control/login"
        credentials = {"name": self._username, "password": self._password}
        status, reason, text, headers = self._raw_request("POST", path, credentials, include_auth=False)
        self._decode_body(path, status, reason, text)  # raises DnsSyncError on non-2xx
        session = self._extract_session_cookie(headers)
        if not session:
            raise DnsSyncError("AdGuard login failed: no agh_session cookie")
        self._session_cookie = session

    @staticmethod
    def _extract_session_cookie(headers: list[tuple[str, str]]) -> str:
        """Return the ``agh_session=<value>`` pair from any ``Set-Cookie`` header, or ""."""
        for name, value in headers:
            if name.lower() != "set-cookie":
                continue
            for piece in value.split(";"):
                piece = piece.strip()
                if piece.startswith("agh_session="):
                    return piece
        return ""

    def _request_with_session(self, method: str, path: str, payload: dict | None, retry_on_auth: bool) -> object:
        """Send a request with the session cookie and return the decoded body.

        When the request is rejected as unauthorized and ``retry_on_auth`` is true, the
        provider logs in again and retries exactly once.

        Raises:
            DnsSyncError: if the request stays unauthorized, or fails for any other reason.
        """
        status, reason, text, _ = self._raw_request(method, path, payload, include_auth=True)
        body: object = None
        unauthorized = status in self._AUTH_REJECTED_STATUSES
        if not unauthorized:
            body = self._decode_body(path, status, reason, text)
            unauthorized = isinstance(body, dict) and body.get("message") == "unauthorized"
        if unauthorized:
            if retry_on_auth:
                self._session_cookie = None
                self._login()
                return self._request_with_session(method, path, payload, retry_on_auth=False)
            raise DnsSyncError("AdGuard request unauthorized")
        return body

    def _raw_request(
        self, method: str, path: str, payload: dict | None, include_auth: bool
    ) -> tuple[int, str, str, list[tuple[str, str]]]:
        """Perform one HTTP exchange on a fresh connection.

        Returns ``(status, reason, body_text, header_pairs)``. The status is not
        checked here; callers decide how to treat it.

        Raises:
            DnsSyncError: on socket or HTTP-protocol errors.
        """
        connection_cls = http.client.HTTPSConnection if self._https else http.client.HTTPConnection
        conn = connection_cls(self._host, self._port, timeout=self._timeout)
        headers = {"Content-Type": "application/json"}
        if include_auth and self._session_cookie:
            headers["Cookie"] = self._session_cookie
        body = json.dumps(payload) if payload is not None else None
        try:
            conn.request(method, f"{self._base_path}{path}", body=body, headers=headers)
            response = conn.getresponse()
            text = response.read().decode("utf-8", errors="replace")
            return response.status, response.reason, text, response.getheaders()
        except (OSError, http.client.HTTPException) as exc:
            # HTTPException (BadStatusLine, IncompleteRead, ...) is not an OSError.
            raise DnsSyncError(f"AdGuard request failed for {path}: {exc}") from exc
        finally:
            conn.close()

    @staticmethod
    def _decode_body(path: str, status: int, reason: str, text: str) -> object:
        """Check a response and decode its body.

        Returns the parsed JSON, ``{}`` for an empty body, or the raw text when the
        body is not JSON.

        Raises:
            DnsSyncError: for non-2xx statuses. The message includes up to 400
                characters of the body.
        """
        if status < 200 or status >= 300:
            raise DnsSyncError(
                f"AdGuard request failed for {path}: HTTP {status} {reason}; body={text[:400]}"
            )
        if not text.strip():
            return {}
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return text
@dataclass(frozen=True)
class Rfc2136Config:
    """Settings for sending RFC 2136 dynamic updates to a DNS server.

    ``tsig_secret`` is excluded from the generated ``repr`` so a config object can be
    logged or shown in tracebacks without leaking the TSIG key material.
    """

    server: str  # DNS server that receives the update messages
    zone: str  # zone the updates are addressed to
    port: int
    timeout_seconds: float
    tsig_key_name: str  # leave empty together with tsig_secret for unsigned updates
    tsig_secret: str = field(repr=False)  # base64-encoded TSIG secret
    tsig_algorithm: str  # e.g. "hmac-sha256"
class Rfc2136DnsProvider:
    """DNS provider that sends RFC 2136 dynamic updates over TCP using dnspython.

    dnspython is imported lazily, so it is only required when this provider is used.
    Updates are TSIG-signed when both a key name and a secret are configured, and
    sent unsigned when both are empty.
    """

    # Accepted TSIG algorithm names, compared after lowercasing the configured value.
    _SUPPORTED_ALGORITHMS = frozenset(
        {
            "hmac-md5.sig-alg.reg.int",
            "hmac-sha1",
            "hmac-sha224",
            "hmac-sha256",
            "hmac-sha384",
            "hmac-sha512",
        }
    )

    def __init__(self, config: Rfc2136Config):
        """Normalise and store the settings.

        Raises:
            ValueError: if the server or zone is blank.
        """
        if not config.server.strip():
            raise ValueError("RFC2136_SERVER is required")
        if not config.zone.strip():
            raise ValueError("RFC2136_ZONE is required")
        self._server = config.server.strip()
        self._zone = config.zone.strip().rstrip(".")
        self._port = config.port
        self._timeout = config.timeout_seconds
        self._key_name = config.tsig_key_name.strip()
        self._secret = config.tsig_secret.strip()
        # TSIG algorithm names are DNS names, so compare them case-insensitively.
        self._algorithm = config.tsig_algorithm.strip().lower() or "hmac-sha256"

    def upsert_a_record(self, fqdn: str, ip: str, ttl: int) -> None:
        """Replace the A RRset of ``fqdn`` with the single record ``ip``.

        The delete and the add are sent in one update message.
        """

        def replace_a(req, target: str) -> None:
            req.delete(target, "A")
            req.add(target, int(ttl), "A", ip)

        self._send_update(fqdn, "upsert", f"{fqdn} -> {ip}", replace_a)

    def delete_a_record(self, fqdn: str) -> None:
        """Delete the A RRset of ``fqdn``."""

        def remove_a(req, target: str) -> None:
            req.delete(target, "A")

        self._send_update(fqdn, "delete", fqdn, remove_a)

    def _send_update(self, fqdn: str, action: str, subject: str, populate) -> None:
        """Build an update for the zone, fill it via ``populate(req, target)``, and send it over TCP.

        ``action`` and ``subject`` only shape the error messages.

        Raises:
            DnsSyncError: if dnspython is missing, the TSIG settings are invalid, sending
                fails, or the server answers with a non-NOERROR rcode.
        """
        rcode, tsigkeyring, update, query = self._dns_modules()
        zone_text = self._zone_with_dot()
        keyring = self._keyring_or_none(tsigkeyring)
        target = self._absolute_name(fqdn)
        try:
            req = update.Update(
                zone_text,
                keyring=keyring,
                keyname=self._key_name or None,
                keyalgorithm=self._algorithm,
            )
            populate(req, target)
            response = query.tcp(req, self._server, port=self._port, timeout=self._timeout)
        except Exception as exc:  # noqa: BLE001 - any dnspython/network failure becomes DnsSyncError
            raise DnsSyncError(f"RFC2136 {action} failed for {subject}: {exc}") from exc
        if response.rcode() != rcode.NOERROR:
            text = rcode.to_text(response.rcode())
            raise DnsSyncError(f"RFC2136 {action} failed for {fqdn}: {text}")

    def _dns_modules(self):
        """Import and return the dnspython modules ``(rcode, tsigkeyring, update, query)``.

        Raises:
            DnsSyncError: if dnspython is not installed.
        """
        try:
            import dns.query as query
            import dns.rcode as rcode
            import dns.tsigkeyring as tsigkeyring
            import dns.update as update
        except ImportError as exc:
            raise DnsSyncError("dnspython is required for RFC2136 mode") from exc
        return rcode, tsigkeyring, update, query

    def _keyring_or_none(self, tsigkeyring):
        """Build the TSIG keyring, or return None when TSIG is not configured.

        Raises:
            DnsSyncError: if only one of key name/secret is set, the secret is not
                valid base64, or the algorithm is not supported.
        """
        if not self._key_name and not self._secret:
            return None
        if not self._key_name or not self._secret:
            raise DnsSyncError("RFC2136 TSIG requires both key name and secret")
        key_name = self._key_name if self._key_name.endswith(".") else f"{self._key_name}."
        try:
            base64.b64decode(self._secret, validate=True)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise DnsSyncError("RFC2136_TSIG_SECRET must be valid base64") from exc
        if self._algorithm not in self._SUPPORTED_ALGORITHMS:
            raise DnsSyncError(f"Unsupported TSIG algorithm: {self._algorithm}")
        return tsigkeyring.from_text({key_name: self._secret})

    def _zone_with_dot(self) -> str:
        """Return the zone as an absolute name (trailing dot)."""
        return self._zone if self._zone.endswith(".") else f"{self._zone}."

    def _absolute_name(self, fqdn: str) -> str:
        """Return ``fqdn`` as an absolute name (trailing dot)."""
        return fqdn if fqdn.endswith(".") else f"{fqdn}."
def build_dns_provider(
    provider_name: str,
    *,
    adguard_url: str,
    adguard_username: str,
    adguard_password: str,
    rfc2136_server: str,
    rfc2136_zone: str,
    rfc2136_port: int,
    rfc2136_tsig_key_name: str,
    rfc2136_tsig_secret: str,
    rfc2136_tsig_algorithm: str,
    timeout_seconds: float,
) -> DnsProvider | None:
    """Create the DNS provider selected by ``provider_name`` (the DNS_PROVIDER setting).

    An empty name or ``"none"`` disables DNS sync and returns ``None``. Matching
    ignores case and surrounding whitespace. ``"adguard"`` and ``"rfc2136"`` build the
    corresponding provider; only the selected provider's keyword arguments are used.

    Raises:
        DnsSyncError: for an unknown provider name or missing AdGuard settings.
        ValueError: propagated from the provider constructors for other invalid
            settings (a non-http(s) ADGUARD_URL, a blank RFC2136 server or zone).
    """
    mode = provider_name.strip().lower()
    if mode in {"", "none"}:
        return None

    if mode == "adguard":
        # Validate and use the same stripped URL; stray whitespace would otherwise
        # end up in the parsed host or path.
        url = adguard_url.strip()
        if not url:
            raise DnsSyncError("ADGUARD_URL is required for DNS_PROVIDER=adguard")
        if not adguard_username.strip() or not adguard_password.strip():
            raise DnsSyncError("ADGUARD_USERNAME and ADGUARD_PASSWORD are required for DNS_PROVIDER=adguard")
        return AdguardDnsProvider(
            AdguardConfig(
                url=url,
                username=adguard_username,
                password=adguard_password,
                timeout_seconds=timeout_seconds,
            )
        )

    if mode == "rfc2136":
        return Rfc2136DnsProvider(
            Rfc2136Config(
                server=rfc2136_server,
                zone=rfc2136_zone,
                port=rfc2136_port,
                timeout_seconds=timeout_seconds,
                tsig_key_name=rfc2136_tsig_key_name,
                tsig_secret=rfc2136_tsig_secret,
                tsig_algorithm=rfc2136_tsig_algorithm,
            )
        )

    raise DnsSyncError(f"Unsupported DNS_PROVIDER: {provider_name}")