# Source code for pydotorg.lib.api_versioning

"""API versioning middleware with version negotiation support.

Provides version negotiation via Accept header, URL path, and query parameters.
Includes deprecation warnings for old API versions.
"""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from functools import total_ordering
from typing import TYPE_CHECKING
from urllib.parse import parse_qs

from litestar.middleware import MiddlewareProtocol

if TYPE_CHECKING:
    from datetime import date

    from litestar.types import ASGIApp, Message, Receive, Scope, Send

# Number of dot-separated parts a version string must exceed before it carries
# a patch component (``parts[2]``); used by ``APIVersion.from_string``.
_MIN_PARTS_FOR_PATCH = 2


@total_ordering
@dataclass(frozen=True)
class APIVersion:
    """Represents an API version with deprecation metadata.

    Equality, hashing and ordering are based solely on the
    ``(major, minor, patch)`` triple. ``deprecated`` and ``sunset_date`` are
    descriptive metadata and never take part in comparisons, so a deprecated
    ``v1.0.0`` compares equal to a non-deprecated ``v1.0.0``.

    ``functools.total_ordering`` derives ``<=``, ``>`` and ``>=`` from the
    hand-written ``__lt__`` and ``__eq__``.
    """

    major: int
    minor: int
    patch: int = 0
    deprecated: bool = False
    sunset_date: date | None = None

    @classmethod
    def from_string(cls, version_str: str) -> APIVersion:
        """Parse version from string like 'v1' or '1.2.3'.

        Parsing is lenient about everything except the major component: a
        missing or non-numeric minor/patch component is read as ``0`` and any
        components after the third are ignored.

        Args:
            version_str: Version string in format 'v1', '1', '1.2', or '1.2.3'.
                Surrounding whitespace, letter case and a leading 'v' are
                ignored.

        Returns:
            Parsed APIVersion instance (never marked as deprecated).

        Raises:
            ValueError: If the major component is missing or not numeric.
        """
        version_str = version_str.strip().lower()
        version_str = version_str.removeprefix("v")

        parts = version_str.split(".")
        if not parts or not parts[0].isdigit():
            msg = f"Invalid version string: {version_str}"
            raise ValueError(msg)

        major = int(parts[0])
        # Optional components silently fall back to 0 when absent or non-numeric.
        minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
        patch = int(parts[2]) if len(parts) > _MIN_PARTS_FOR_PATCH and parts[2].isdigit() else 0

        return cls(major=major, minor=minor, patch=patch)

    def _key(self) -> tuple[int, int, int]:
        """Return the ``(major, minor, patch)`` triple used for comparisons."""
        return (self.major, self.minor, self.patch)

    def __str__(self) -> str:
        """Return version as string (e.g., 'v1.0.0')."""
        return f"v{self.major}.{self.minor}.{self.patch}"

    def to_short_string(self) -> str:
        """Return short version string (e.g., 'v1')."""
        return f"v{self.major}"

    def __lt__(self, other: object) -> bool:
        """Compare versions for sorting.

        Returns ``NotImplemented`` for operands that are not ``APIVersion`` so
        Python raises ``TypeError`` instead of an ``AttributeError`` leaking
        out of the comparison.
        """
        if not isinstance(other, APIVersion):
            return NotImplemented
        return self._key() < other._key()

    def __eq__(self, other: object) -> bool:
        """Compare versions for equality (metadata is ignored)."""
        if not isinstance(other, APIVersion):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        """Return hash for use in sets and dicts (consistent with ``__eq__``)."""
        return hash(self._key())
class APIVersionRegistry:
    """Registry of supported API versions with deprecation metadata.

    Versions are keyed by their major number, so registering a version whose
    major is already present replaces the earlier entry.
    """

    def __init__(self) -> None:
        """Initialize the version registry with default versions."""
        self._versions: dict[int, APIVersion] = {}
        self._default_version: APIVersion | None = None
        self._register_default_versions()

    def _register_default_versions(self) -> None:
        """Register the default supported API versions."""
        v1 = APIVersion(major=1, minor=0, patch=0)
        self.register(v1, is_default=True)

    def register(
        self,
        version: APIVersion,
        *,
        is_default: bool = False,
    ) -> None:
        """Register a new API version.

        The version becomes the default when ``is_default`` is true, when no
        default exists yet, or when it replaces the entry that is currently the
        default (same major number). The last rule keeps the default in sync
        with the registry: without it, re-registering the default's major with
        new deprecation metadata would leave ``get_default()`` returning the
        stale, superseded object.

        Args:
            version: The version to register (with deprecation metadata already set)
            is_default: Whether this should be the default version
        """
        self._versions[version.major] = version

        current_default = self._default_version
        supersedes_default = current_default is not None and current_default.major == version.major
        if is_default or current_default is None or supersedes_default:
            self._default_version = version

    def get_version(self, major: int) -> APIVersion | None:
        """Get a version by major version number.

        Args:
            major: The major version number

        Returns:
            The APIVersion if found, None otherwise
        """
        return self._versions.get(major)

    def get_default(self) -> APIVersion:
        """Get the default API version.

        Returns:
            The default APIVersion

        Raises:
            RuntimeError: If no default version is set
        """
        if self._default_version is None:
            msg = "No default API version configured"
            raise RuntimeError(msg)
        return self._default_version

    def is_supported(self, version: APIVersion) -> bool:
        """Check if a version is supported.

        Only the major number is considered; minor and patch are ignored.

        Args:
            version: The version to check

        Returns:
            True if the version is supported, False otherwise
        """
        return version.major in self._versions
# Module-level registry instance shared by this module: the middleware reads
# from it, and ``deprecate_version`` / ``add_version`` write to it.
registry = APIVersionRegistry()
class APIVersionMiddleware(MiddlewareProtocol):
    """Middleware that negotiates API version and adds version headers.

    Only HTTP requests whose path starts with ``/api/`` are processed; every
    other request is passed through untouched.

    Version negotiation order:
    1. Accept header (application/vnd.pydotorg.v1+json)
    2. URL path prefix (/api/v1/...)
    3. Query parameter (?api_version=1)
    4. Default version (v1)

    Adds response headers (emitted with lowercase names, as the ASGI
    specification requires; HTTP header names are case-insensitive):
    - X-API-Version: Current version used
    - X-API-Deprecated: true/false
    - X-API-Sunset-Date: ISO date if deprecated and a sunset date is set
    - X-API-Deprecation-Info: JSON details, sent together with the sunset date
    """

    ACCEPT_PATTERN = re.compile(r"application/vnd\.pydotorg\.v(\d+)\+json")
    PATH_PATTERN = re.compile(r"^/api/v(\d+)/")

    def __init__(self, app: ASGIApp) -> None:
        """Initialize the middleware.

        Args:
            app: The ASGI application
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process the request and add version headers to response.

        The negotiated version is also stored in ``scope["api_version"]`` for
        downstream handlers.

        Args:
            scope: The ASGI scope
            receive: The receive callable
            send: The send callable
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        path = scope.get("path", "")
        if not path.startswith("/api/"):
            await self.app(scope, receive, send)
            return

        version = self._negotiate_version(scope)
        scope["api_version"] = version

        async def send_with_version_headers(message: Message) -> None:
            """Add version headers to the response.

            Args:
                message: The ASGI message
            """
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                # ASGI requires lowercase header names in http.response.start.
                headers.append((b"x-api-version", str(version).encode()))

                if version.deprecated:
                    headers.append((b"x-api-deprecated", b"true"))
                    if version.sunset_date:
                        sunset_str = version.sunset_date.isoformat()
                        headers.append((b"x-api-sunset-date", sunset_str.encode()))

                        deprecation_info = {
                            "deprecated": True,
                            "sunset_date": sunset_str,
                            "message": f"API version {version} is deprecated and will be removed on {sunset_str}",
                            "current_version": str(registry.get_default()),
                        }
                        headers.append((b"x-api-deprecation-info", json.dumps(deprecation_info).encode()))
                else:
                    headers.append((b"x-api-deprecated", b"false"))

                message["headers"] = headers

            await send(message)

        await self.app(scope, receive, send_with_version_headers)

    def _negotiate_version(self, scope: Scope) -> APIVersion:
        """Negotiate API version from request.

        Checks in order:
        1. Accept header
        2. URL path
        3. Query parameter
        4. Default version

        A source is skipped when it yields no version or a version whose major
        number is not registered.

        Args:
            scope: The ASGI scope

        Returns:
            The negotiated APIVersion
        """
        extractors = (
            self._extract_version_from_accept,
            self._extract_version_from_path,
            self._extract_version_from_query,
        )
        for extract in extractors:
            requested = extract(scope)
            if requested and registry.is_supported(requested):
                # Return the registry's own object so deprecation metadata is included.
                return registry.get_version(requested.major) or registry.get_default()

        return registry.get_default()

    @staticmethod
    def _parse_version(text: str) -> APIVersion | None:
        """Parse ``text`` into an APIVersion, returning None when it is invalid."""
        try:
            return APIVersion.from_string(text)
        except ValueError:
            return None

    def _extract_version_from_accept(self, scope: Scope) -> APIVersion | None:
        """Extract version from Accept header.

        Looks for: application/vnd.pydotorg.v1+json

        Args:
            scope: The ASGI scope

        Returns:
            APIVersion if found and valid, None otherwise
        """
        # A request may carry several Accept header lines; consider all of them.
        accept_values = [value for name, value in scope.get("headers", []) if name == b"accept"]
        # latin-1 maps every byte to a character, so malformed headers cannot
        # raise UnicodeDecodeError here.
        accept_header = b",".join(accept_values).decode("latin-1")

        match = self.ACCEPT_PATTERN.search(accept_header)
        if match:
            return self._parse_version(match.group(1))
        return None

    def _extract_version_from_path(self, scope: Scope) -> APIVersion | None:
        """Extract version from URL path.

        Looks for: /api/v1/...

        Args:
            scope: The ASGI scope

        Returns:
            APIVersion if found and valid, None otherwise
        """
        path = scope.get("path", "")
        match = self.PATH_PATTERN.search(path)
        if match:
            return self._parse_version(match.group(1))
        return None

    def _extract_version_from_query(self, scope: Scope) -> APIVersion | None:
        """Extract version from query parameter.

        Looks for: ?api_version=1

        Uses ``urllib.parse.parse_qs`` so that unrelated parameters whose value
        contains ``=`` (for example base64 padding) cannot break parsing, and
        so that percent-encoded values are decoded. When the parameter is
        repeated, the last occurrence wins.

        Args:
            scope: The ASGI scope

        Returns:
            APIVersion if found and valid, None otherwise
        """
        # latin-1 never fails to decode, unlike the default UTF-8.
        query_string = scope.get("query_string", b"").decode("latin-1")
        if not query_string:
            return None

        # Blank values are dropped by parse_qs, so "?api_version=" yields no entry.
        values = parse_qs(query_string).get("api_version")
        if values:
            return self._parse_version(values[-1])
        return None
def deprecate_version(major: int, sunset_date: date | None = None) -> None:
    """Flag a registered API version as deprecated.

    The registry entry for ``major`` is replaced by a copy with
    ``deprecated=True`` and the given ``sunset_date``. If that version was the
    default it stays the default. Nothing happens when ``major`` is not
    registered.

    Args:
        major: The major version number to deprecate
        sunset_date: Optional sunset date when version will be removed

    Example:
        >>> from datetime import date
        >>> deprecate_version(1, sunset_date=date(2025, 12, 31))
    """
    existing = registry.get_version(major)
    if not existing:
        return

    # APIVersion is frozen, so build a replacement instead of mutating it.
    replacement = APIVersion(
        major=existing.major,
        minor=existing.minor,
        patch=existing.patch,
        deprecated=True,
        sunset_date=sunset_date,
    )
    keeps_default = registry.get_default() == existing
    registry.register(replacement, is_default=keeps_default)
def add_version(major: int, minor: int = 0, patch: int = 0, *, is_default: bool = False) -> APIVersion:
    """Create an API version and store it in the module-level registry.

    Args:
        major: Major version number
        minor: Minor version number (default: 0)
        patch: Patch version number (default: 0)
        is_default: Whether this should be the default version

    Returns:
        The created APIVersion

    Example:
        >>> v2 = add_version(2, 0, 0, is_default=True)
        >>> str(v2)
        'v2.0.0'
    """
    created = APIVersion(major=major, minor=minor, patch=patch)
    registry.register(created, is_default=is_default)
    return created