299 lines
9.6 KiB
Python
299 lines
9.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Contracts to enforce loader objects
|
|
"""
|
|
|
|
# Import python libs
|
|
import asyncio
|
|
import functools
|
|
import inspect
|
|
import os
|
|
from collections import namedtuple
|
|
from types import ModuleType
|
|
from typing import Dict, Iterable, List
|
|
|
|
# Import pop libs
|
|
import pop.exc
|
|
import pop.hub
|
|
import pop.verify
|
|
|
|
|
|
class ContractedContext(
    namedtuple(
        "ContractedContext",
        ("func", "args", "kwargs", "signature", "ret", "cache", "ref"),
    )
):
    """
    Contracted function calling context

    An immutable record passed to contract routines. ``args`` is stored as a
    list, and ``cache`` is a dict that ``_replace`` copies share by reference,
    so a signature binding computed through one copy is reused by the others.
    """

    def __new__(
        cls,
        func: functools.partial,
        args: Iterable,
        kwargs: Dict,
        signature,
        ref,
        ret=None,
        cache=None,
    ):  # pylint: disable=too-many-arguments
        # Give each new context its own cache dict rather than a shared default
        if cache is None:
            cache = {}
        # The tuple field order ends with ref, unlike the constructor's parameter order
        return super(ContractedContext, cls).__new__(
            cls, func, list(args), kwargs, signature, ret, cache, ref
        )

    def get_argument(self, name):
        """
        Return the value corresponding to a function argument after binding the contract context
        argument and keyword arguments to the function signature.

        :param name: The parameter name as it appears in the signature
        :raises KeyError: If ``name`` is not a parameter of the signature
        """
        return self.get_arguments()[name]

    def get_arguments(self):
        """
        Return a dictionary of all arguments that will be passed to the function and their
        values, including default arguments.

        The binding is cached in ``self.cache`` under "__bound_signature__" on
        first use. It is not recomputed if ``args``/``kwargs`` are mutated
        afterwards.

        :raises pop.exc.BindError: If the arguments cannot be bound to the signature
        """
        if "__bound_signature__" not in self.cache:
            try:
                self.cache["__bound_signature__"] = self.signature.bind(
                    *self.args, **self.kwargs
                )
            except TypeError as e:
                # Only translate TypeErrors raised by Signature.bind itself
                # (stdlib inspect.py); any other TypeError is re-raised as-is.
                for frame in inspect.trace(0):
                    if frame.function == "bind" and frame.filename.endswith(
                        os.sep + "inspect.py"
                    ):
                        # Chain explicitly so the binding failure is the recorded cause
                        raise pop.exc.BindError(e) from e
                raise
        # Apply any default values from the signature
        self.cache["__bound_signature__"].apply_defaults()
        return self.cache["__bound_signature__"].arguments
|
|
|
|
|
|
def load_contract(
    contracts: List["Contracted"],
    default_contracts: List[str],
    mod: ModuleType,
    name: str,
) -> List:
    """
    Collect the contracts that apply to a loaded module.

    Contracts are gathered in this order, each included at most once:

    1. the contract named after the module (``name``), if ``contracts`` has it;
    2. the ``init`` contract, if ``contracts`` has it;
    3. every contract listed in ``default_contracts``;
    4. every contract named in the module's ``__contracts__`` attribute (a
       list/tuple of names, or a comma-separated string) that ``contracts``
       contains.

    :param contracts: Contracts container, queried by attribute and with ``in``
        (presumably the loaded contracts sub; verify against callers)
    :param default_contracts: The contracts that have been marked as defaults
    :param mod: A loader module
    :param name: The name of the module to get from the loader
    :return: The list of contract objects, empty when ``contracts`` is falsy
    :raises AttributeError: If a default contract is not present in ``contracts``
    """
    raws = []
    if not contracts:
        return raws
    loaded_contracts = []

    def _add(cname: str):
        # Append a contract once, keeping first-seen order
        if cname in loaded_contracts:
            return
        loaded_contracts.append(cname)
        raws.append(getattr(contracts, cname))

    if hasattr(contracts, name):
        _add(name)
    # _add dedups, so a module literally named "init" no longer gets its
    # init contract twice (which ran its pre/post routines twice)
    if hasattr(contracts, "init"):
        _add("init")
    for contract in default_contracts or ():
        _add(contract)
    if hasattr(mod, "__contracts__"):
        cnames = getattr(mod, "__contracts__")
        if not isinstance(cnames, (list, tuple)):
            # Accept "a, b" and trailing commas: strip whitespace, drop empties
            cnames = [cname.strip() for cname in cnames.split(",") if cname.strip()]
        for cname in cnames:
            if cname in contracts:
                _add(cname)
    return raws
|
|
|
|
|
|
class Wrapper:
    """
    Callable proxy around a function referenced on the hub.

    The function's ``__dict__`` is copied onto the wrapper so custom function
    attributes remain reachable. The original is exposed as ``__wrapped__``,
    the attribute ``inspect.unwrap`` follows.
    """

    def __init__(self, func: functools.partial, ref: str, name: str):
        """
        :param func: The contracted function to call
        :param ref: The reference to the function on the hub
        :param name: An alias for the function
        """
        # Copy the function's attributes first so the assignments below win on conflicts
        self.__dict__.update(getattr(func, "__dict__", {}))
        self.func = func
        self.ref = ref
        self.__name__ = name
        self.signature = inspect.signature(self.func)
        self._sig_errors = []
        self.__wrapped__ = func

    def __call__(self, *args, **kwargs):
        # Propagate the wrapped function's result to the caller
        return self.func(*args, **kwargs)

    def __repr__(self):
        return "<{} func={}.{}>".format(
            self.__class__.__name__, self.func.__module__, self.__name__
        )
|
|
|
|
|
|
class Contracted(Wrapper):
    """
    Wrapper for a synchronous function that has contracts attached.

    Each invocation runs, in order:

    - the "pre" routines;
    - the first "call" routine, or the function itself when there is none;
    - the "post" routines.

    A "post" routine that returns something other than None replaces the
    return value.
    """

    def __init__(
        self,
        hub: "pop.hub.Hub",
        contracts: List[Wrapper],
        func: functools.partial,
        ref: str,
        name: str,
    ):
        super().__init__(func, ref, name)
        self.hub = hub
        self.contracts = contracts or []
        self._load_contracts()

    def _get_contracts_by_type(self, contract_type: str = "pre") -> List[Wrapper]:
        """
        Gather every contract routine of one type for this function.

        For each contract, the function-specific routine ``<type>_<name>`` is
        taken before the generic ``<type>`` routine, when present.

        :param contract_type: One of "call", "pre", "post", or "sig"
        :return: The matching routines, in contract order
        """
        specific_name = "{}_{}".format(contract_type, self.__name__)
        found = []
        for contract in self.contracts:
            for attr in (specific_name, contract_type):
                if hasattr(contract, attr):
                    found.append(getattr(contract, attr))
        return found

    def _load_contracts(self):
        # TODO:
        # if Contracted - only allow regular pre/post
        # if ContractedAsync - allow coroutines and functions
        # if ContractedAsyncGen - allow coroutines and functions
        pre_fns = self._get_contracts_by_type("pre")
        call_fns = self._get_contracts_by_type("call")
        post_fns = self._get_contracts_by_type("post")
        # Only the first "call" routine is ever used
        self.contract_functions = {
            "pre": pre_fns,
            "call": call_fns[:1],
            "post": post_fns,
        }
        # TODO: write test for stack-like behavior (reverse "pre")
        self._has_contracts = any(self.contract_functions.values())

    def __call__(self, *args, **kwargs):
        call_args = (self.hub, *args)
        if not self._has_contracts:
            return self.func(*call_args, **kwargs)

        ctx = ContractedContext(self.func, call_args, kwargs, self.signature, self.ref)
        for pre_fn in self.contract_functions["pre"]:
            pre_fn(ctx)

        call_fns = self.contract_functions["call"]
        ret = call_fns[0](ctx) if call_fns else self.func(*ctx.args, **ctx.kwargs)

        for post_fn in self.contract_functions["post"]:
            result = post_fn(ctx._replace(ret=ret))
            if result is not None:
                ret = result
        return ret
|
|
|
|
|
|
class ContractedAsyncGen(Contracted):
    """
    Contracted wrapper for async generator functions.

    Chunks are re-yielded to the caller. They come from the first "call"
    routine when there is one, otherwise from the wrapped function. "pre" and
    "post" routines may be plain functions; any coroutine they return is
    awaited.

    The last yielded chunk (None if nothing was yielded) is handed to the
    "post" routines as ``ret``.
    """

    async def __call__(self, *args, **kwargs):
        args = (self.hub,) + args
        if not self._has_contracts:
            async for chunk in self.func(*args, **kwargs):
                yield chunk
            return
        contract_context = ContractedContext(
            self.func, args, kwargs, self.signature, self.ref
        )

        for fn in self.contract_functions["pre"]:
            pre_ret = fn(contract_context)
            if asyncio.iscoroutine(pre_ret):
                await pre_ret

        chunk = None
        if self.contract_functions["call"]:
            # The "call" routine replaces the function and must return an async iterable
            source = self.contract_functions["call"][0](contract_context)
        else:
            source = self.func(*contract_context.args, **contract_context.kwargs)
        async for chunk in source:
            yield chunk

        # An async generator cannot return a value, so a "post" routine's
        # result only feeds the ``ret`` seen by later "post" routines.
        ret = chunk
        for fn in self.contract_functions["post"]:
            post_ret = fn(contract_context._replace(ret=ret))
            # Await any coroutine, not only ContractedAsync wrappers, so plain
            # coroutine functions actually run (mirrors the "pre" handling)
            if asyncio.iscoroutine(post_ret):
                post_ret = await post_ret
            if post_ret is not None:
                ret = post_ret
|
|
|
|
|
|
class ContractedAsync(Contracted):
    """
    Contracted wrapper for coroutine functions.

    - "pre" and "post" routines may be plain functions; any coroutine they
      return is awaited.
    - The "call" routine's result is awaited when it is awaitable.
    - A "post" routine returning something other than None replaces the
      return value.
    """

    async def __call__(self, *args, **kwargs):
        args = (self.hub,) + args
        if not self._has_contracts:
            return await self.func(*args, **kwargs)
        contract_context = ContractedContext(
            self.func, args, kwargs, self.signature, self.ref
        )

        for fn in self.contract_functions["pre"]:
            pre_ret = fn(contract_context)
            if asyncio.iscoroutine(pre_ret):
                await pre_ret

        if self.contract_functions["call"]:
            ret = self.contract_functions["call"][0](contract_context)
            # Await every awaitable, as before; a synchronous "call" routine
            # returning a plain value is now accepted instead of failing
            if inspect.isawaitable(ret):
                ret = await ret
        else:
            ret = await self.func(*contract_context.args, **contract_context.kwargs)

        for fn in self.contract_functions["post"]:
            post_ret = fn(contract_context._replace(ret=ret))
            # Await any coroutine, not only ContractedAsync wrappers, so plain
            # coroutine functions actually run (mirrors the "pre" handling)
            if asyncio.iscoroutine(post_ret):
                post_ret = await post_ret
            if post_ret is not None:
                ret = post_ret

        return ret
|
|
|
|
|
|
def create_contracted(
    hub: "pop.hub.Hub",
    contracts: List[Wrapper],
    func: functools.partial,
    ref: str,
    name: str,
) -> Contracted:
    """
    Build the Contracted wrapper that matches the kind of ``func``.

    - Coroutine functions get ContractedAsync.
    - Async generator functions get ContractedAsyncGen.
    - Anything else gets a plain Contracted.

    :param hub: The pop hub passed as the first argument to ``func``
    :param contracts: Contracts functions to add to the sub
    :param func: The contracted function to call
    :param ref: The reference to the function on the hub
    :param name: An alias for the function
    """
    if asyncio.iscoroutinefunction(func):
        contracted_cls = ContractedAsync
    elif inspect.isasyncgenfunction(func):
        contracted_cls = ContractedAsyncGen
    else:
        contracted_cls = Contracted
    return contracted_cls(hub, contracts, func, ref, name)
|