"""
Support for multirange types adaptation.
"""

# Copyright (C) 2021 The Psycopg Team

from __future__ import annotations

from decimal import Decimal
from typing import Any, Generic, Iterable, MutableSequence, overload, TYPE_CHECKING
from datetime import date, datetime

from .. import sql
from .. import _oids
from .. import errors as e
from .. import postgres
from ..pq import Format
from ..abc import AdaptContext, Buffer, Dumper, DumperKey, Query
from ..adapt import RecursiveDumper, RecursiveLoader, PyFormat
from .._oids import INVALID_OID, TEXT_OID
from .._compat import cache
from .._struct import pack_len, unpack_len
from .._typeinfo import TypeInfo, TypesRegistry

from .range import Range, T, load_range_text, load_range_binary
from .range import dump_range_text, dump_range_binary, fail_dump

if TYPE_CHECKING:
    from .._connection_base import BaseConnection


class MultirangeInfo(TypeInfo):
    """Catalog information describing a PostgreSQL multirange type.

    On top of what the base class receives, the object stores `!range_oid`
    and `!subtype_oid`, which the info query reads from the ``rngtypid`` and
    ``rngsubtype`` columns of ``pg_range``.
    """

    def __init__(
        self,
        name: str,
        oid: int,
        array_oid: int,
        *,
        regtype: str = "",
        range_oid: int,
        subtype_oid: int,
    ):
        super().__init__(name, oid, array_oid, regtype=regtype)
        self.range_oid = range_oid
        self.subtype_oid = subtype_oid

    @classmethod
    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
        """Return the query fetching the details of the type from the catalog.

        :raises NotSupportedError: if the server version reported by `!conn`
            is lower than 14.
        """
        if conn.info.server_version < 140000:
            raise e.NotSupportedError(
                "multirange types are only available from PostgreSQL 14"
            )

        # The `{regtype}` placeholder is replaced by the expression selecting
        # the type requested.
        template = sql.SQL(
            """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
    t.oid::regtype::text AS regtype,
    r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngmultitypid
WHERE t.oid = {regtype}
"""
        )
        return template.format(regtype=cls._to_regtype(conn))

    def _added(self, registry: TypesRegistry) -> None:
        """Index this object in `!registry` by range oid and by subtype oid."""
        for key_oid in (self.range_oid, self.subtype_oid):
            registry._registry[MultirangeInfo, key_oid] = self


class Multirange(MutableSequence[Range[T]]):
    """Python representation for a PostgreSQL multirange type.

    :param items: Sequence of ranges to initialise the object.

    The object is a mutable sequence accepting only `Range` instances: a
    `!TypeError` is raised on attempts to store any other object.
    """

    def __init__(self, items: Iterable[Range[T]] = ()):
        self._ranges: list[Range[T]] = list(map(self._check_type, items))

    def _check_type(self, item: Any) -> Range[Any]:
        """Return `!item` if it is a `Range`, otherwise raise `!TypeError`."""
        if not isinstance(item, Range):
            raise TypeError(
                f"Multirange is a sequence of Range, got {type(item).__name__}"
            )
        return item

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._ranges!r})"

    def __str__(self) -> str:
        return f"{{{', '.join(map(str, self._ranges))}}}"

    @overload
    def __getitem__(self, index: int) -> Range[T]: ...

    @overload
    def __getitem__(self, index: slice) -> Multirange[T]: ...

    def __getitem__(self, index: int | slice) -> Range[T] | Multirange[T]:
        """Return the range at `!index`, or a new `Multirange` for a slice."""
        if isinstance(index, int):
            return self._ranges[index]
        else:
            return Multirange(self._ranges[index])

    def __len__(self) -> int:
        return len(self._ranges)

    @overload
    def __setitem__(self, index: int, value: Range[T]) -> None: ...

    @overload
    def __setitem__(self, index: slice, value: Iterable[Range[T]]) -> None: ...

    def __setitem__(
        self, index: int | slice, value: Range[T] | Iterable[Range[T]]
    ) -> None:
        """Replace the range at `!index`, or the ranges selected by a slice.

        :raises TypeError: if the value assigned is not a `Range` (or, for a
            slice, not an iterable of `!Range`).
        """
        if isinstance(index, int):
            # A single validation is enough: the value is checked before the
            # list is touched.
            self._ranges[index] = self._check_type(value)
        elif not isinstance(value, Iterable):
            raise TypeError("can only assign an iterable")
        else:
            # Validate every item before assigning the slice, so that a wrong
            # item leaves the multirange unchanged.
            self._ranges[index] = [self._check_type(item) for item in value]

    def __delitem__(self, index: int | slice) -> None:
        del self._ranges[index]

    def insert(self, index: int, value: Range[T]) -> None:
        """Insert `!value` before `!index`, after checking it is a `Range`."""
        self._ranges.insert(index, self._check_type(value))

    def __eq__(self, other: Any) -> bool:
        # Return NotImplemented, like the ordering methods below, to give the
        # other operand a chance to perform the comparison. Comparing with an
        # unrelated object still evaluates to False.
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges == other._ranges

    # Order is arbitrary but consistent

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges < other._ranges

    def __le__(self, other: Any) -> bool:
        return self == other or self < other  # type: ignore

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, Multirange):
            return NotImplemented
        return self._ranges > other._ranges

    def __ge__(self, other: Any) -> bool:
        return self == other or self > other  # type: ignore


# Subclasses to specify a specific subtype. Usually not needed


class Int4Multirange(Multirange[int]):
    """`Multirange` subclass registered with the ``int4multirange`` dumpers."""


class Int8Multirange(Multirange[int]):
    """`Multirange` subclass registered with the ``int8multirange`` dumpers."""


class NumericMultirange(Multirange[Decimal]):
    """`Multirange` subclass registered with the ``nummultirange`` dumpers."""


class DateMultirange(Multirange[date]):
    """`Multirange` subclass registered with the ``datemultirange`` dumpers."""


class TimestampMultirange(Multirange[datetime]):
    """`Multirange` subclass registered with the ``tsmultirange`` dumpers."""


class TimestamptzMultirange(Multirange[datetime]):
    """`Multirange` subclass registered with the ``tstzmultirange`` dumpers."""


class BaseMultirangeDumper(RecursiveDumper):
    """Behaviour shared by the text and the binary multirange dumpers.

    A dumper registered on the generic `Multirange` class cannot know the
    PostgreSQL type to use in advance: `get_key()` and `upgrade()` look at a
    bound of the object to dump in order to select a dumper with a suitable
    oid. Dumpers registered on a `!Multirange` subclass are used as they are.
    """

    def __init__(self, cls: type, context: AdaptContext | None = None):
        super().__init__(cls, context)
        self.sub_dumper: Dumper | None = None
        self._adapt_format = PyFormat.from_pq(self.format)

    def get_key(self, obj: Multirange[Any], format: PyFormat) -> DumperKey:
        """Return the key identifying the dumper needed to dump `!obj`.

        The key is the class alone for `Multirange` subclasses; for the
        generic class it is a tuple also including the key of the dumper of
        a bound, if any bound is available.
        """
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self.cls

        item = self._get_item(obj)
        if item is not None:
            sd = self._tx.get_dumper(item, self._adapt_format)
            return (self.cls, sd.get_key(item, format))
        else:
            return (self.cls,)

    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> BaseMultirangeDumper:
        """Return a dumper specialised to dump `!obj`.

        The dumper returned is `!self` for `Multirange` subclasses; otherwise
        it is a new dumper whose oid is chosen according to the type of a
        bound of `!obj`.
        """
        # If we are a subclass whose oid is specified we don't need upgrade
        if self.cls is not Multirange:
            return self

        item = self._get_item(obj)
        if item is None:
            # No bound is available to infer the multirange type from, so no
            # specific oid can be chosen. As in the int case below, use the
            # text dumper, even if we are a binary one.
            return MultirangeDumper(self.cls, self._tx)

        dumper: BaseMultirangeDumper
        if type(item) is int:
            # postgres won't cast int4range -> int8range so we must use
            # text format and unknown oid here
            sd = self._tx.get_dumper(item, PyFormat.TEXT)
            dumper = MultirangeDumper(self.cls, self._tx)
            dumper.sub_dumper = sd
            dumper.oid = INVALID_OID
            return dumper

        sd = self._tx.get_dumper(item, format)
        dumper = type(self)(self.cls, self._tx)
        dumper.sub_dumper = sd
        if sd.oid == INVALID_OID and isinstance(item, str):
            # Work around the normal mapping where text is dumped as unknown
            dumper.oid = self._get_multirange_oid(TEXT_OID)
        else:
            dumper.oid = self._get_multirange_oid(sd.oid)

        return dumper

    def _get_item(self, obj: Multirange[Any]) -> Any:
        """
        Return a member representative of the multirange.

        The member is the first bound found which is not `!None`, looking at
        the lower and then at the upper bound of each range; return `!None`
        if there is no such bound.
        """
        for r in obj:
            if r.lower is not None:
                return r.lower
            if r.upper is not None:
                return r.upper
        return None

    def _get_multirange_oid(self, sub_oid: int) -> int:
        """
        Return the oid of the multirange from the oid of its elements.

        Return `!INVALID_OID` if no `MultirangeInfo` is found for `!sub_oid`.
        """
        info = self._tx.adapters.types.get_by_subtype(MultirangeInfo, sub_oid)
        return info.oid if info else INVALID_OID


class MultirangeDumper(BaseMultirangeDumper):
    """
    Dumper for multirange types.

    The dumper can upgrade to one specific for a different range type.
    """

    def dump(self, obj: Multirange[Any]) -> Buffer | None:
        """Return `!obj` as comma-separated ranges enclosed in braces."""
        if not obj:
            return b"{}"

        # Use the dumper of a representative bound for every bound; if the
        # multirange has no bound at all use `fail_dump` instead.
        bound = self._get_item(obj)
        if bound is None:
            dump_bound = fail_dump
        else:
            dump_bound = self._tx.get_dumper(bound, self._adapt_format).dump

        ranges = b",".join(dump_range_text(rng, dump_bound) for rng in obj)
        return b"{" + ranges + b"}"


class MultirangeBinaryDumper(BaseMultirangeDumper):
    """Binary format dumper for multirange types."""

    format = Format.BINARY

    def dump(self, obj: Multirange[Any]) -> Buffer | None:
        """Return the number of ranges followed by each range dumped.

        Both the number of ranges and the length preceding each dumped range
        are encoded by `!pack_len()`.
        """
        # Use the dumper of a representative bound for every bound; if the
        # multirange has no bound at all use `fail_dump` instead.
        bound = self._get_item(obj)
        if bound is None:
            dump_bound = fail_dump
        else:
            dump_bound = self._tx.get_dumper(bound, self._adapt_format).dump

        chunks: list[Buffer] = [pack_len(len(obj))]
        for rng in obj:
            payload = dump_range_binary(rng, dump_bound)
            chunks.extend((pack_len(len(payload)), payload))
        return b"".join(chunks)


class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
    """Base class of the multirange loaders.

    Subclasses must define `!subtype_oid`: the loader found for that oid,
    in the same format of this loader, is used to load the range bounds.
    """

    # Oid of the type of the range bounds
    subtype_oid: int

    def __init__(self, oid: int, context: AdaptContext | None = None):
        super().__init__(oid, context)
        bound_loader = self._tx.get_loader(self.subtype_oid, format=self.format)
        self._load = bound_loader.load


class MultirangeLoader(BaseMultirangeLoader[T]):
    """Loader of multirange types in text format."""

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse `!data` as a sequence of ranges enclosed in braces.

        :raises DataError: if `!data` doesn't start with an open brace, if a
            range is not followed by a comma or by a closing brace, or if
            anything follows the closing brace.
        """
        if not data or data[0] != _START_INT:
            raise e.DataError(
                "malformed multirange starting with"
                f" {bytes(data[:1]).decode('utf8', 'replace')}"
            )

        result = Multirange[T]()
        if data == b"{}":
            return result

        # Skip the open brace; afterwards `rest` always starts with a range
        rest = data[1:]
        try:
            while True:
                rng, end = load_range_text(rest, self._load)
                result.append(rng)

                delim = rest[end]  # can raise IndexError
                if delim == _SEP_INT:
                    # Another range follows: drop what was parsed so far
                    rest = rest[end + 1 :]
                    continue

                if delim != _END_INT:
                    raise e.DataError(
                        f"malformed multirange: found unexpected {chr(delim)}"
                    )

                # The closing brace must be the last character
                if len(rest) != end + 1:
                    raise e.DataError(
                        "malformed multirange: data after closing brace"
                    )
                return result

        except IndexError:
            raise e.DataError("malformed multirange: separator missing")


# Integer values of the punctuation of the multirange text representation.
# The text loader compares them with the items obtained indexing the buffer.
_SEP_INT = ord(",")
_START_INT = ord("{")
_END_INT = ord("}")


class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
    """Loader of multirange types in binary format."""

    format = Format.BINARY

    def load(self, data: Buffer) -> Multirange[T]:
        """Parse `!data` as a count of ranges followed by the ranges.

        Every range is preceded by its length; both the count and the
        lengths are read by `!unpack_len()` and take 4 bytes each.

        :raises DataError: if the ranges don't consume `!data` entirely.
        """
        count = unpack_len(data, 0)[0]
        offset = 4
        result = Multirange[T]()
        for _ in range(count):
            size = unpack_len(data, offset)[0]
            start = offset + 4
            offset = start + size
            result.append(load_range_binary(data[start:offset], self._load))

        if offset != len(data):
            raise e.DataError("unexpected trailing data in multirange")

        return result


def register_multirange(
    info: MultirangeInfo, context: AdaptContext | None = None
) -> None:
    """Register the adapters to load and dump a multirange type.

    :param info: The object with the information about the multirange to
        register.
    :param context: The context where to register the adapters. If `!None`,
        register it globally.

    Register loaders so that loading data of this type will result in a
    `Multirange` whose ranges have bounds parsed as the right subtype.

    .. note::

        Registering the adapters doesn't affect objects already created, even
        if they are children of the registered context. For instance,
        registering the adapter globally doesn't affect already existing
        connections.
    """
    # Fail with a friendly message, instead of an AttributeError further
    # down, in case fetch() failed and it wasn't noticed.
    if not info:
        raise TypeError("no info passed. Is the requested multirange available?")

    # Register arrays and type info
    info.register(context)

    adapters = postgres.adapters if not context else context.adapters

    # Generate the loaders customized for the subtype, text first and then
    # binary, registering each of them on the oid of the multirange.
    for make in (_make_loader, _make_binary_loader):
        loader: type[BaseMultirangeLoader[Any]] = make(info.name, info.subtype_oid)
        adapters.register_loader(info.oid, loader)


# Cache all dynamically-generated types to avoid leaks in case the types
# cannot be GC'd.


@cache
def _make_loader(name: str, oid: int) -> type[MultirangeLoader[Any]]:
    """Return a `MultirangeLoader` subclass having `!oid` as `!subtype_oid`."""
    attrs = {"subtype_oid": oid}
    return type(f"{name.title()}Loader", (MultirangeLoader,), attrs)


@cache
def _make_binary_loader(name: str, oid: int) -> type[MultirangeBinaryLoader[Any]]:
    """Return a `MultirangeBinaryLoader` subclass with `!oid` as `!subtype_oid`."""
    attrs = {"subtype_oid": oid}
    return type(f"{name.title()}BinaryLoader", (MultirangeBinaryLoader,), attrs)


# Text dumpers for the builtin multirange type wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.


class Int4MultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``int4multirange`` oid."""

    oid = _oids.INT4MULTIRANGE_OID


class Int8MultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``int8multirange`` oid."""

    oid = _oids.INT8MULTIRANGE_OID


class NumericMultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``nummultirange`` oid."""

    oid = _oids.NUMMULTIRANGE_OID


class DateMultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``datemultirange`` oid."""

    oid = _oids.DATEMULTIRANGE_OID


class TimestampMultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``tsmultirange`` oid."""

    oid = _oids.TSMULTIRANGE_OID


class TimestamptzMultirangeDumper(MultirangeDumper):
    """Text dumper declaring the ``tstzmultirange`` oid."""

    oid = _oids.TSTZMULTIRANGE_OID


# Binary dumpers for the builtin multirange type wrappers
# These are registered on specific subtypes so that the upgrade mechanism
# doesn't kick in.


class Int4MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``int4multirange`` oid."""

    oid = _oids.INT4MULTIRANGE_OID


class Int8MultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``int8multirange`` oid."""

    oid = _oids.INT8MULTIRANGE_OID


class NumericMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``nummultirange`` oid."""

    oid = _oids.NUMMULTIRANGE_OID


class DateMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``datemultirange`` oid."""

    oid = _oids.DATEMULTIRANGE_OID


class TimestampMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``tsmultirange`` oid."""

    oid = _oids.TSMULTIRANGE_OID


class TimestamptzMultirangeBinaryDumper(MultirangeBinaryDumper):
    """Binary dumper declaring the ``tstzmultirange`` oid."""

    oid = _oids.TSTZMULTIRANGE_OID


# Text loaders for builtin multirange types


class Int4MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose range bounds are loaded as ``int4`` values."""

    subtype_oid = _oids.INT4_OID


class Int8MultirangeLoader(MultirangeLoader[int]):
    """Text loader whose range bounds are loaded as ``int8`` values."""

    subtype_oid = _oids.INT8_OID


class NumericMultirangeLoader(MultirangeLoader[Decimal]):
    """Text loader whose range bounds are loaded as ``numeric`` values."""

    subtype_oid = _oids.NUMERIC_OID


class DateMultirangeLoader(MultirangeLoader[date]):
    """Text loader whose range bounds are loaded as ``date`` values."""

    subtype_oid = _oids.DATE_OID


class TimestampMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose range bounds are loaded as ``timestamp`` values."""

    subtype_oid = _oids.TIMESTAMP_OID


class TimestampTZMultirangeLoader(MultirangeLoader[datetime]):
    """Text loader whose range bounds are loaded as ``timestamptz`` values."""

    subtype_oid = _oids.TIMESTAMPTZ_OID


# Binary loaders for builtin multirange types


class Int4MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose range bounds are loaded as ``int4`` values."""

    subtype_oid = _oids.INT4_OID


class Int8MultirangeBinaryLoader(MultirangeBinaryLoader[int]):
    """Binary loader whose range bounds are loaded as ``int8`` values."""

    subtype_oid = _oids.INT8_OID


class NumericMultirangeBinaryLoader(MultirangeBinaryLoader[Decimal]):
    """Binary loader whose range bounds are loaded as ``numeric`` values."""

    subtype_oid = _oids.NUMERIC_OID


class DateMultirangeBinaryLoader(MultirangeBinaryLoader[date]):
    """Binary loader whose range bounds are loaded as ``date`` values."""

    subtype_oid = _oids.DATE_OID


class TimestampMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose range bounds are loaded as ``timestamp`` values."""

    subtype_oid = _oids.TIMESTAMP_OID


class TimestampTZMultirangeBinaryLoader(MultirangeBinaryLoader[datetime]):
    """Binary loader whose range bounds are loaded as ``timestamptz`` values."""

    subtype_oid = _oids.TIMESTAMPTZ_OID


def register_default_adapters(context: AdaptContext) -> None:
    """Register the multirange dumpers and loaders on `!context`."""
    adapters = context.adapters

    # NOTE(review): the sequence of the registrations is the original one.
    # The generic class gets the binary dumper first and the text one last,
    # the subclasses the other way around: presumably the order decides which
    # dumper is preferred - confirm before reordering.
    dumpers = (
        (Multirange, MultirangeBinaryDumper),
        (Multirange, MultirangeDumper),
        (Int4Multirange, Int4MultirangeDumper),
        (Int8Multirange, Int8MultirangeDumper),
        (NumericMultirange, NumericMultirangeDumper),
        (DateMultirange, DateMultirangeDumper),
        (TimestampMultirange, TimestampMultirangeDumper),
        (TimestamptzMultirange, TimestamptzMultirangeDumper),
        (Int4Multirange, Int4MultirangeBinaryDumper),
        (Int8Multirange, Int8MultirangeBinaryDumper),
        (NumericMultirange, NumericMultirangeBinaryDumper),
        (DateMultirange, DateMultirangeBinaryDumper),
        (TimestampMultirange, TimestampMultirangeBinaryDumper),
        (TimestamptzMultirange, TimestamptzMultirangeBinaryDumper),
    )
    for cls, dumper in dumpers:
        adapters.register_dumper(cls, dumper)

    # Text loaders first, then binary loaders, for each builtin type name
    loaders = (
        ("int4multirange", Int4MultirangeLoader),
        ("int8multirange", Int8MultirangeLoader),
        ("nummultirange", NumericMultirangeLoader),
        ("datemultirange", DateMultirangeLoader),
        ("tsmultirange", TimestampMultirangeLoader),
        ("tstzmultirange", TimestampTZMultirangeLoader),
        ("int4multirange", Int4MultirangeBinaryLoader),
        ("int8multirange", Int8MultirangeBinaryLoader),
        ("nummultirange", NumericMultirangeBinaryLoader),
        ("datemultirange", DateMultirangeBinaryLoader),
        ("tsmultirange", TimestampMultirangeBinaryLoader),
        ("tstzmultirange", TimestampTZMultirangeBinaryLoader),
    )
    for typename, loader in loaders:
        adapters.register_loader(typename, loader)
