# Source code for failures.functions

"""
The functions.py module implements function decorators and utilities that reduce and automate common tasks
needed by this library, such as labeling the top-level scope of functions.
"""
from __future__ import annotations
import types
import inspect
import functools
from typing import Callable, Any, Optional, Union, Dict, Tuple, overload, cast, Awaitable

from typing_extensions import get_origin, get_args

from .core import Reporter, P, T, _invalid

# Parameter name that marks the reporter argument when the parameter has no annotation.
REPORTER_ARG_NAME = 'reporter'
# Sentinel used for "no annotation" / "no default", so it can be told apart from an explicit ``None``.
NOTSET = object()

# Origins that identify a union annotation when returned by ``get_origin``.
_union_types: list = [Union]
if hasattr(types, 'UnionType'):
    # ``types.UnionType`` does not exist on every supported interpreter, so it is registered only when present.
    _union_types.append(types.UnionType)


def _is_param_reporter(name: str, annotation: Any, default: Any) -> bool:
    """
    Tell whether a parameter should be treated as the reporter argument.

    A parameter qualifies when it is annotated with ``Reporter`` (directly or as a member of a union), or when
    it carries no annotation (``NOTSET``) and is named ``REPORTER_ARG_NAME``. In both cases its default must be
    absent (``NOTSET``), ``None`` or a ``Reporter`` instance.

    :param name: The parameter name.
    :param annotation: The parameter annotation, or ``NOTSET`` when there is none.
    :param default: The parameter default, or ``NOTSET`` when there is none.
    :return: ``True`` when the parameter is recognised as the reporter, ``False`` otherwise.
    """
    if annotation is Reporter:
        hinted_by_type = True
    else:
        # e.g. ``Optional[Reporter]`` / ``Union[Reporter, None]``
        hinted_by_type = get_origin(annotation) in _union_types and Reporter in get_args(annotation)

    hinted_by_name = annotation is NOTSET and name == REPORTER_ARG_NAME

    if not (hinted_by_type or hinted_by_name):
        return False
    return default is NOTSET or default is None or isinstance(default, Reporter)


def _get_set_reporter_from_signature(func: Callable, /) -> Tuple[
    Callable[[tuple, Dict[str, Any]], Optional[Reporter]],
    Callable[[tuple, Dict[str, Any], Reporter], Tuple[tuple, Dict[str, Any]]],
]:
    """
    Build the accessors for the reporter parameter declared in the signature of ``func``.

    The signature is scanned for the first parameter accepted by ``_is_param_reporter``; positional parameters
    are checked first, keyword-only parameters second.

    :param func: The callable whose signature is inspected.
    :return: A ``(get, set)`` pair of functions:

        * ``get(args, kwargs)`` returns the reporter found in the call arguments, or the parameter's default
          when the argument was omitted; when the parameter has no default it raises the error built by
          ``_invalid(TypeError, ...)``.
        * ``set(args, kwargs, reporter)`` returns the ``(args, kwargs)`` pair to call ``func`` with, where
          ``reporter`` fills the slot of the reporter parameter.

        When no parameter qualifies, ``get`` always returns ``None`` and ``set`` returns the arguments unchanged.
    """
    spec = inspect.getfullargspec(func)
    name = get_func_name(func)

    # NOTE: the closures below capture the loop variables, which is safe only because
    # the function returns as soon as the first matching parameter is found.

    # Check if positional arguments contain the reporter
    for idx, _arg in enumerate(spec.args):
        try:
            # ``spec.defaults`` lines up with the tail of ``spec.args``, hence the negative index;
            # it is ``None`` (-> TypeError) when no positional parameter has a default.
            default = spec.defaults[idx - len(spec.args)]  # type: ignore
        except (IndexError, TypeError):
            default = NOTSET
        if not _is_param_reporter(_arg, spec.annotations.get(_arg, NOTSET), default):
            continue

        def _get(_args: tuple, _kwargs: Dict[str, Any], /) -> Optional[Reporter]:
            if idx < len(_args):
                return _args[idx]
            # A positional-or-keyword parameter may just as well be supplied by keyword
            if _arg in _kwargs:
                return _kwargs[_arg]
            if default is not NOTSET:
                return default
            raise _invalid(TypeError, f"{name}() is missing the reporter as required positional argument") from None

        def _set(_args: tuple, _kwargs: Dict[str, Any], _reporter: Reporter, /) -> Tuple[tuple, Dict[str, Any]]:
            if idx < len(_args):
                # The reporter was passed positionally: replace it in place
                args = list(_args)
                args[idx] = _reporter
                return tuple(args), _kwargs
            if idx == len(_args) and _arg not in _kwargs:
                # The reporter slot is exactly the next positional one: appending lands on it
                return (*_args, _reporter), _kwargs
            # Either the reporter was passed by keyword, or preceding parameters were omitted/passed by keyword.
            # Appending positionally would bind the reporter to the wrong parameter (or duplicate the argument),
            # so it is passed by name instead. A copy is used to leave the received mapping untouched.
            return _args, {**_kwargs, _arg: _reporter}

        return _get, _set

    # Check if keyword arguments contain the reporter
    for _arg in spec.kwonlyargs:
        try:
            # ``spec.kwonlydefaults`` is ``None`` (-> TypeError) when no keyword-only parameter has a default
            default = spec.kwonlydefaults[_arg]  # type: ignore  # Brutal access
        except (KeyError, TypeError):
            default = NOTSET
        if not _is_param_reporter(_arg, spec.annotations.get(_arg, NOTSET), default):
            continue

        def _get(_args: tuple, _kwargs: Dict[str, Any], /) -> Optional[Reporter]:
            try:
                return _kwargs[_arg]
            except KeyError:
                if default is not NOTSET:
                    return default
                raise _invalid(TypeError, f"{name}() is missing the reporter as required keyword argument") from None

        def _set(_args: tuple, _kwargs: Dict[str, Any], _reporter: Reporter, /) -> Tuple[tuple, Dict[str, Any]]:
            _kwargs[_arg] = _reporter
            return _args, _kwargs

        return _get, _set

    # Return default 'get' and 'set' functions (the signature declares no reporter parameter)
    def _get(_args: tuple, _kwargs: Dict[str, Any], /) -> Optional[Reporter]:  # type: ignore[no-redef]
        return None

    def _set(  # type: ignore[no-redef]
            _args: tuple,
            _kwargs: Dict[str, Any],
            _reporter: Reporter, /
    ) -> Tuple[tuple, Dict[str, Any]]:
        return _args, _kwargs

    return _get, _set


def is_async(func: Callable, /) -> bool:
    """
    Tell whether calling ``func`` produces a coroutine.

    ``functools.partial`` layers are peeled off first. The remaining object counts as asynchronous when it is a
    coroutine function itself, or when its ``__call__`` attribute is one (e.g. an instance of a class that
    defines ``async def __call__``).

    :param func: The callable to inspect.
    :return: ``True`` for asynchronous callables, ``False`` otherwise.
    """
    target = func
    while isinstance(target, functools.partial):
        target = target.func

    if inspect.iscoroutinefunction(target):
        return True
    if not hasattr(target, '__call__'):
        return False
    return inspect.iscoroutinefunction(target.__call__)


def get_func_name(func: Callable, /) -> str:
    """
    Return a name for ``func``.

    ``functools.partial`` layers are unwrapped first, so the name of the underlying callable is reported.
    Callable instances (objects that define ``__call__``) have no ``__name__`` attribute; for those the name
    of their class is returned instead of raising ``AttributeError``.

    :param func: The callable to name.
    :return: ``func.__name__`` when available, otherwise the name of ``func``'s type.
    """
    while isinstance(func, functools.partial):
        func = func.func
    try:
        return func.__name__
    except AttributeError:
        # Not a function/class: fall back to the class of the callable object
        return type(func).__name__


# Bare usage: ``@scoped`` receives the function itself and returns the wrapped function.
@overload
def scoped(func: Callable[P, T], /) -> Callable[P, T]:
    ...


# Parametrized usage: ``@scoped()`` or ``@scoped("name")`` returns the actual decorator.
@overload
def scoped(name: str = ..., /) -> Callable[[Callable[P, T]], Callable[P, T]]:
    ...


def scoped(arg=None, /):
    """
    Decorate a function so that any inner failure is intercepted and labeled with the function's name;
    the same name is also added to the reporter passed as argument to that function, if any.

    The decorator can be used bare, with empty parentheses, or with an explicit scope name.

    >>> @scoped
    ... def function(*args, **kwargs):
    ...     pass

    or

    >>> @scoped('my_function')
    ... def function(*args, **kwargs):
    ...     pass

    :param arg: The function to decorate, the scope name, or ``None``.
    :return: The wrapped function when a callable is given, otherwise a decorator.
    :raises: The error built by ``_invalid(TypeError, ...)`` when ``arg`` is neither ``None``, a callable
        nor a string.
    """

    def decorator(func_: Callable[P, T], /, name: str = None) -> Callable[P, T]:
        read_reporter, write_reporter = _get_set_reporter_from_signature(func_)
        label = name or get_func_name(func_)

        def get_reporter(args, kwargs) -> Reporter:
            # Derive the reporter of this scope from the one received by the call (if any)
            received = read_reporter(args, kwargs)
            if received is None:
                return Reporter(label)
            if isinstance(received, Reporter):
                return received(label)
            raise _invalid(TypeError, f"The reporter got wrong type {type(received)!r}")

        if is_async(func_):
            @functools.wraps(func_)
            async def wrap(*args: P.args, **kwargs: P.kwargs) -> T:
                scope = get_reporter(args, kwargs)
                call_args, call_kwargs = write_reporter(args, kwargs, scope)
                with scope:
                    return await cast(Callable[P, Awaitable[T]], func_)(*call_args, **call_kwargs)
        else:
            @functools.wraps(func_)
            def wrap(*args: P.args, **kwargs: P.kwargs) -> T:
                scope = get_reporter(args, kwargs)
                call_args, call_kwargs = write_reporter(args, kwargs, scope)
                with scope:
                    return func_(*call_args, **call_kwargs)

        # Expose the signature of the decorated function on the wrapper
        setattr(wrap, '__signature__', inspect.signature(func_))
        return cast(Callable[P, T], wrap)

    if arg is None:
        # Used with empty parentheses: @scoped()
        return decorator
    if callable(arg):
        # Used without parentheses: @scoped
        return decorator(arg)
    if isinstance(arg, str):
        # Called with a name: @scoped("name")
        return lambda func: decorator(func, arg)
    raise _invalid(TypeError, "@scoped decorator expects a callable as first argument")