first commit

This commit is contained in:
2020-11-03 18:30:14 -08:00
commit 31d8522470
1881 changed files with 345408 additions and 0 deletions

View File

@@ -0,0 +1,696 @@
import copy
import fnmatch
import logging
import re
from collections import OrderedDict
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Dict, Iterable, List, Set

import yaml

from . import DEFAULT_TARGET_DELIM
from . import args
# Module-level logger, named after this module, used by the helpers below.
log = logging.getLogger(__name__)
class CaseInsensitiveDict(MutableMapping):
    """
    Inspired by requests' case-insensitive dict implementation, but works with
    non-string keys as well.

    Every entry is stored in ``self._data`` under ``to_lowercase(key)`` and
    maps to an ``(original_key, value)`` tuple. Lookups therefore ignore case,
    while iteration yields each key with the casing it was last set with.
    """

    def __init__(self, init=None, **kwargs):
        """
        :param init: Optional mapping, or iterable of ``(key, value)`` pairs,
            used to populate the new object.
        :param kwargs: Additional key/value pairs to store.
        """
        # The backing store is a plain dict; entries are iterated in the
        # order in which their (lowercased) key was first inserted.
        self._data = {}
        self.update(init or {}, **kwargs)

    def __len__(self):
        return len(self._data)

    def __setitem__(self, key, value):
        # Store the case-sensitive key so it is available for dict iteration
        self._data[to_lowercase(key)] = (key, value)

    def __delitem__(self, key):
        del self._data[to_lowercase(key)]

    def __getitem__(self, key):
        return self._data[to_lowercase(key)][1]

    def __iter__(self):
        # Yield the keys with their original casing
        return (item[0] for item in self._data.values())

    def __eq__(self, rval):
        if not isinstance(rval, Mapping):
            # Comparing to non-mapping type (e.g. int) is always False
            return False
        # Normalize both sides to lowercase keys before comparing
        return dict(self.items_lower()) == dict(CaseInsensitiveDict(rval).items_lower())

    def __repr__(self):
        return repr(dict(self.items()))

    def items_lower(self):
        """
        Returns a generator iterating over keys and values, with the keys all
        being lowercase.
        """
        return ((key, val[1]) for key, val in self._data.items())

    def copy(self):
        """
        Returns a copy of the object

        The copy is built from the stored ``(original_key, value)`` pairs, so
        it keeps both the original key casing and the original values.
        """
        # The values of the backing store are exactly the (key, value) pairs
        # the constructor expects. Using ``.items()`` here would turn every
        # value of the copy into a ``(key, value)`` tuple.
        return CaseInsensitiveDict(self._data.values())
class ImmutableDict(Mapping):
    """
    A read-only mapping that implements the interface of a `dict`.
    Items can be retrieved by key or via namespacing (attribute access).
    No values can be changed after initialization.

    Nested dicts are converted to ``ImmutableDict`` instances and other
    iterables (except ``tuple``, ``str`` and ``bytes``) are converted to tuples.
    """

    def __init__(self, init_: Dict[str, Any], **c_kwargs):
        """
        :param init_: A dictionary from which to inherit data
        :param c_kwargs: Extra key/value pairs; they override keys of ``init_``
        """
        # Merge into a fresh dict so the caller's dictionary is never mutated
        merged = dict(init_)
        merged.update(c_kwargs)
        values = {}
        for k, v in merged.items():
            if isinstance(v, ImmutableDict):
                # Already immutable; keep it instead of flattening it into a
                # tuple of its keys via the generic Iterable branch below.
                values[k] = v
            elif isinstance(v, dict):
                values[k] = ImmutableDict(init_=v)
            elif isinstance(v, (tuple, int, str, bytes)):
                values[k] = v
            elif isinstance(v, Iterable):
                values[k] = tuple(v)
            else:
                values[k] = v
        # __setattr__ is borked (on purpose) so we have to call it from super() right here
        super().__setattr__("_ImmutableDict__store", values)

    def __setattr__(self, k: str, v: Any):
        raise TypeError(
            f"{self.__class__.__name__} does not support attribute assignment"
        )

    def __getattr__(self, k: str):
        """
        Return the value stored under key ``k``.

        Only called when normal attribute lookup fails. A missing key raises
        an error that is both an ``AttributeError`` and a ``KeyError``.
        """
        try:
            # Bypass __getattr__ so an instance whose store has not been set
            # yet (e.g. while being reconstructed by copy/pickle) cannot
            # recurse endlessly.
            store = object.__getattribute__(self, "_ImmutableDict__store")
        except AttributeError:
            raise AttributeError(k) from None
        try:
            return store[k]
        except KeyError:
            raise _ImmutableDictKeyError(k) from None

    def __getitem__(self, k: str) -> Any:
        return self.__store[k]

    def __contains__(self, k: str) -> bool:
        return k in self.__store

    def __iter__(self):
        return iter(self.__store)

    def __len__(self) -> int:
        return len(self.__store)

    def __copy__(self) -> Dict[str, Any]:
        """
        Return a plain (mutable) ``dict`` holding the same data.
        """
        ret = {}
        # Unpack IMAP items so that it's turtles all the way down
        for k, v in self.__store.items():
            if isinstance(v, ImmutableDict):
                ret[k] = v.__copy__()
            else:
                ret[k] = v
        return ret

    def __repr__(self):
        return repr(copy.copy(self))


class _ImmutableDictKeyError(AttributeError, KeyError):
    """
    Raised by ``ImmutableDict.__getattr__`` for a missing key.

    Being an ``AttributeError`` makes ``hasattr``, ``getattr`` with a default
    and ``copy.deepcopy`` work; being a ``KeyError`` keeps callers that catch
    ``KeyError`` around attribute access working.
    """
class NamespaceDict(dict):
    """
    A dictionary that can access its string keys through the namespace
    """

    def __init__(self, seq: Iterable = None, **kwargs):
        """
        NamespaceDict() -> new empty namespaced dictionary
        NamespaceDict(mapping) -> new namespaced dictionary initialized from a mapping object's
            (key, value) pairs
        NamespaceDict(iterable) -> new namespaced dictionary initialized as if via:
            d = {}
            for k, v in iterable:
                d[k] = v
        NamespaceDict(**kwargs) -> new namespaced dictionary initialized with the name=value pairs
            in the keyword argument list.  For example:  NamespaceDict(one=1, two=2)
        """
        if seq is None:
            super().__init__(**kwargs)
        else:
            super().__init__(seq, **kwargs)

    def __setattr__(self, k: str, v: Any):
        """
        Store ``v`` under key ``k``; plain dicts are wrapped in a
        ``NamespaceDict`` so they can be accessed through the namespace too.
        """
        if isinstance(v, dict) and not isinstance(v, NamespaceDict):
            v = NamespaceDict(v)
        self[k] = v

    def __getattr__(self, k: str):
        """
        Return the value stored under key ``k``.

        Only called when normal attribute lookup fails. Names starting with
        an underscore are never looked up as keys. A missing key raises an
        error that is both an ``AttributeError`` and a ``KeyError``.
        """
        if k.startswith("_"):
            return super().__getattribute__(k)
        try:
            return self[k]
        except KeyError:
            raise _NamespaceDictKeyError(k) from None

    def __copy__(self):
        return NamespaceDict(self.copy())

    def __deepcopy__(self, memodict=None):
        if memodict is None:
            memodict = {}
        return NamespaceDict(copy.deepcopy(self.copy(), memodict))


class _NamespaceDictKeyError(AttributeError, KeyError):
    """
    Raised by ``NamespaceDict.__getattr__`` for a missing key.

    Being an ``AttributeError`` makes ``hasattr`` and ``getattr`` with a
    default work; being a ``KeyError`` keeps callers that catch ``KeyError``
    around attribute access working.
    """
def __change_case(data, attr, preserve_dict_class=False):
    """
    Apply the zero-argument method named ``attr`` to ``data`` when it has one.

    When it does not, mappings and sequences are rebuilt with the conversion
    applied recursively to their members (for mappings: to keys and values).
    Anything else is returned unchanged.

    :param data: Value to convert.
    :param attr: Name of the method to call, e.g. ``"lower"``.
    :param preserve_dict_class: Rebuild mappings with their own class instead
        of a plain ``dict``.
    """
    try:
        return getattr(data, attr)()
    except AttributeError:
        # No such method on this object; try its members instead
        pass

    def _convert(member):
        return __change_case(member, attr, preserve_dict_class)

    original_class = data.__class__
    if isinstance(data, Mapping):
        factory = original_class if preserve_dict_class else dict
        return factory((_convert(k), _convert(v)) for k, v in data.items())
    if isinstance(data, Sequence):
        return original_class(_convert(member) for member in data)
    return data
def _remove_circular_refs(ob, _seen: Set = None):
    """
    Generic method to remove circular references from objects.
    This has been taken from author Martijn Pieters
    https://stackoverflow.com/questions/44777369/
    remove-circular-references-in-dicts-lists-tuples/44777477#44777477

    :param ob: dict, list, tuple, set, and frozenset
        Standard python object
    :param object _seen:
        Set of ``id()`` values of the containers currently being traversed;
        callers normally leave it unset.
    :returns:
        Cleaned Python object. Containers are rebuilt; every reference that
        points back to an object currently being traversed is replaced with
        ``None``.
    """
    if _seen is None:
        _seen = set()
    ob_id = id(ob)
    if ob_id in _seen:
        # Here we caught a circular reference.
        # Alert user and cleanup to continue.
        # No exception is being handled at this point, so log a plain error
        # rather than log.exception(), which would append an empty traceback.
        log.error(
            "Caught a circular reference in data structure below. "
            "Cleaning and continuing execution.\n%r\n",
            ob,
        )
        return None
    _seen.add(ob_id)
    res = ob
    if isinstance(ob, dict):
        res = {
            _remove_circular_refs(k, _seen): _remove_circular_refs(v, _seen)
            for k, v in ob.items()
        }
    elif isinstance(ob, (list, tuple, set, frozenset)):
        res = type(ob)(_remove_circular_refs(v, _seen) for v in ob)
    # remove id again; only *nested* references count
    _seen.remove(ob_id)
    return res
def compare_dicts(old: Dict = None, new: Dict = None) -> Dict[str, Dict]:
    """
    Compare before and after results from various salt functions, returning a
    dict describing the changes that were made.

    :param old: The "before" dictionary; ``None`` is treated as empty.
    :param new: The "after" dictionary; ``None`` is treated as empty.
    :return: One ``{"old": ..., "new": ...}`` entry per key that was added,
        removed or modified. The missing side of an added or removed key is
        reported as an empty string.
    """
    # Normalize up front so the membership tests below cannot hit ``None``
    old = old or {}
    new = new or {}
    ret = {}
    for key in set(new).union(old):
        if key not in old:
            # New key
            ret[key] = {"old": "", "new": new[key]}
        elif key not in new:
            # Key removed
            ret[key] = {"new": "", "old": old[key]}
        elif new[key] != old[key]:
            # Key modified
            ret[key] = {"old": old[key], "new": new[key]}
    return ret
def object_to_dict(obj) -> Dict:
    """
    Recursively convert an object into plain containers.

    Lists and tuples become lists of converted members. Objects that have a
    ``__dict__`` become dicts of their converted attributes, skipping every
    attribute whose name starts with an underscore. Anything else is
    returned as-is.
    """
    if isinstance(obj, (list, tuple)):
        return [object_to_dict(member) for member in obj]
    if not hasattr(obj, "__dict__"):
        return obj
    return {
        name: object_to_dict(value)
        for name, value in obj.__dict__.items()
        if not name.startswith("_")
    }
def is_dictlist(data: List) -> bool:
    """
    Tell whether ``data`` is a list made up solely of one-element dicts (as
    found in many SLS schemas).

    An empty list qualifies; anything that is not a list does not.
    """
    if not isinstance(data, list):
        return False
    return all(isinstance(entry, dict) and len(entry) == 1 for entry in data)
def recursive_diff(
    old: Iterable,
    new: Iterable,
    ignore_keys: List = None,
    ignore_order: bool = False,
    ignore_missing_keys: bool = False,
) -> Dict[str, Iterable]:
    """
    Performs a recursive diff on mappings and/or iterables and returns the result
    in a {'old': values, 'new': values}-style.
    Compares dicts and sets unordered (obviously), OrderedDicts and Lists ordered
    (but only if both ``old`` and ``new`` are of the same type),
    all other Mapping types unordered, and all other iterables ordered.

    :param old: Mapping or Iterable to compare from.
    :param new: Mapping or Iterable to compare to.
    :param ignore_keys: List of keys to ignore when comparing Mappings.
    :param ignore_order: Compare ordered mapping/iterables as if they were unordered.
    :param ignore_missing_keys: Do not return keys only present in ``old``
        but missing in ``new``. Only works for regular dicts (i.e. for
        mappings that are compared unordered).
    :return dict: Returns dict with keys 'old' and 'new' containing the differences.
        An empty dict is returned when no differences were found.
    """
    ignore_keys = ignore_keys or []

    def _diff(sub_old, sub_new):
        # Recurse with the same comparison options
        return recursive_diff(
            sub_old,
            sub_new,
            ignore_keys=ignore_keys,
            ignore_order=ignore_order,
            ignore_missing_keys=ignore_missing_keys,
        )

    # Work on deep copies so the inputs are never modified
    ret_old = copy.deepcopy(old)
    ret_new = copy.deepcopy(new)

    if (
        isinstance(old, OrderedDict)
        and isinstance(new, OrderedDict)
        and not ignore_order
    ):
        # Compare ordered: keys are paired up by position
        for key_old, key_new in zip(old, new):
            if key_old == key_new:
                if key_old in ignore_keys:
                    del ret_old[key_old]
                    del ret_new[key_new]
                    continue
                res = _diff(old[key_old], new[key_new])
                if res:
                    ret_old[key_old] = res["old"]
                    ret_new[key_new] = res["new"]
                else:  # Equal
                    del ret_old[key_old]
                    del ret_new[key_new]
            else:
                if key_old in ignore_keys:
                    del ret_old[key_old]
                if key_new in ignore_keys:
                    del ret_new[key_new]
        # If the OrderedDicts were of unequal length, the surplus key/values
        # are already present in the copies; only drop the ignored ones.
        paired = min(len(old), len(new))
        for key in list(old)[paired:]:
            if key in ignore_keys:
                del ret_old[key]
        for key in list(new)[paired:]:
            if key in ignore_keys:
                del ret_new[key]
        return {"old": ret_old, "new": ret_new} if ret_old or ret_new else {}

    if isinstance(old, Mapping) and isinstance(new, Mapping):
        # Compare unordered
        for key in set(old).union(new):
            if key in ignore_keys:
                # An ignored key may be present on one side only
                if key in ret_old:
                    del ret_old[key]
                if key in ret_new:
                    del ret_new[key]
            elif key in old and key in new:
                res = _diff(old[key], new[key])
                if res:
                    ret_old[key] = res["old"]
                    ret_new[key] = res["new"]
                else:  # Equal
                    del ret_old[key]
                    del ret_new[key]
            elif ignore_missing_keys and key in old:
                del ret_old[key]
        return {"old": ret_old, "new": ret_new} if ret_old or ret_new else {}

    if isinstance(old, set) and isinstance(new, set):
        only_old = old - new
        only_new = new - old
        return {"old": only_old, "new": only_new} if only_old or only_new else {}

    if (
        isinstance(old, Iterable)
        and not isinstance(old, str)
        and isinstance(new, Iterable)
        and not isinstance(new, str)
    ):
        # Create lists so we can work on an index-basis.
        items_old = list(ret_old)
        items_new = list(ret_new)
        if ignore_order:
            # Pair every old item with at most one not yet paired, equal new
            # item, so duplicates are matched one-to-one.
            unmatched_new = list(range(len(items_new)))
            list_old = []
            for item_old in items_old:
                for pos, idx_new in enumerate(unmatched_new):
                    if not _diff(item_old, items_new[idx_new]):
                        del unmatched_new[pos]
                        break
                else:
                    list_old.append(item_old)
            list_new = [items_new[idx_new] for idx_new in unmatched_new]
        else:
            list_old, list_new = [], []
            for item_old, item_new in zip(items_old, items_new):
                res = _diff(item_old, item_new)
                if res:
                    list_old.append(res["old"])
                    list_new.append(res["new"])
            # Items without a counterpart on the other side are differences
            paired = min(len(items_old), len(items_new))
            list_old.extend(items_old[paired:])
            list_new.extend(items_new[paired:])
        # Instantiate a new whatever-it-was using the list as iterable source.
        # This may not be the most optimized in way of speed and memory usage,
        # but it will work for all iterable types.
        if list_old or list_new:
            return {"old": type(old)(list_old), "new": type(new)(list_new)}
        return {}

    return {} if old == new else {"old": ret_old, "new": ret_new}
def repack_dictlist(data, strict=False, recurse=False, key_cb=None, val_cb=None):
    """
    Takes a list of one-element dicts (as found in many SLS schemas) and
    repacks into a single dictionary.

    :param data: List to repack. A string is YAML-loaded first. Elements may
        be one-element dicts or str/int/float values; the latter end up as
        keys mapped to ``None``.
    :param strict: When a value is itself a dictlist and ``recurse`` is
        False, log an error and return an empty dict instead of storing it.
    :param recurse: Repack values that are themselves dictlists.
    :param key_cb: Callback applied to every key; defaults to the identity.
    :param val_cb: Callback called as ``val_cb(key, val)`` for every value
        taken from a one-element dict; defaults to returning ``val``.
    :return: The repacked dictionary, or an empty dict on invalid input.
    """
    if isinstance(data, str):
        try:
            data = yaml.safe_load(data)
        except yaml.YAMLError as err:
            # YAMLError is the base class of all PyYAML errors, so scanner
            # and composer errors are handled here as well as parser errors.
            log.error(err)
            return {}
    if key_cb is None:
        key_cb = lambda x: x
    if val_cb is None:
        val_cb = lambda x, y: y
    valid_non_dict = (str, int, float)
    if not isinstance(data, list):
        log.error(
            "Invalid input for repack_dictlist, data passed is not a list " "(%s)", data
        )
        return {}
    # Validate every element before building anything
    for element in data:
        if isinstance(element, valid_non_dict):
            continue
        if isinstance(element, dict):
            if len(element) != 1:
                log.error(
                    "Invalid input for repack_dictlist: key/value pairs "
                    "must contain only one element (data passed: %s).",
                    element,
                )
                return {}
        else:
            log.error(
                "Invalid input for repack_dictlist: element %s is "
                "not a string/dict/numeric value",
                element,
            )
            return {}
    ret = {}
    for element in data:
        if isinstance(element, valid_non_dict):
            ret[key_cb(element)] = None
            continue
        key = next(iter(element))
        val = element[key]
        if is_dictlist(val):
            if recurse:
                # NOTE(review): strict, key_cb and val_cb are not forwarded
                # to the nested call - confirm whether that is intended.
                ret[key_cb(key)] = repack_dictlist(val, recurse=recurse)
            elif strict:
                log.error(
                    "Invalid input for repack_dictlist: nested dictlist "
                    "found, but recurse is set to False"
                )
                return {}
            else:
                ret[key_cb(key)] = val_cb(key, val)
        else:
            ret[key_cb(key)] = val_cb(key, val)
    return ret
def subdict_match(
    data: Dict,
    expr: str,
    delimiter: str = DEFAULT_TARGET_DELIM,
    regex_match: bool = False,
    exact_match: bool = False,
) -> bool:
    """
    Check for a match in a dictionary using a delimiter character to denote
    levels of subdicts, and also allowing the delimiter character to be
    matched. Thus, 'foo:bar:baz' will match data['foo'] == 'bar:baz' and
    data['foo']['bar'] == 'baz'. The latter would take priority over the
    former, as more deeply-nested matches are tried first.

    Values are compared as lowercased strings.

    :param data: The dictionary to search.
    :param expr: The delimited expression to match.
    :param delimiter: The string separating the levels in ``expr``.
    :param regex_match: Treat the value part of ``expr`` as a regular
        expression, matched with ``re.match``.
    :param exact_match: Require the value to be equal to the pattern instead
        of using ``fnmatch`` globbing. Not consulted when ``regex_match`` is set.
    :return: True if ``expr`` matches, otherwise False.
    """
    # A pattern starting with "*<delimiter>" matches under any key
    wildcard_prefix = "*" + delimiter

    def _match(target, pattern):
        target = str(target).lower()
        pattern = str(pattern).lower()
        if regex_match:
            try:
                return re.match(pattern, target)
            except Exception:  # pylint: disable=broad-except
                log.error("Invalid regex '%s' in match", pattern)
                return False
        else:
            return (
                target == pattern if exact_match else fnmatch.fnmatch(target, pattern)
            )

    def _dict_match(target, pattern):
        ret = False
        wildcard = pattern.startswith(wildcard_prefix)
        if wildcard:
            pattern = pattern[len(wildcard_prefix):]

        if pattern == "*":
            # We are just checking that the key exists
            ret = True
        if not ret and pattern in target:
            # We might want to search for a key
            ret = True
        # Forward the delimiter so nested expressions are split the same way
        # as the top-level one.
        if not ret and subdict_match(
            target,
            pattern,
            delimiter=delimiter,
            regex_match=regex_match,
            exact_match=exact_match,
        ):
            ret = True
        if not ret and wildcard:
            for key in target:
                if isinstance(target[key], dict):
                    if _dict_match(target[key], pattern):
                        return True
                elif isinstance(target[key], list):
                    for item in target[key]:
                        if _match(item, pattern):
                            return True
                elif _match(target[key], pattern):
                    return True
        return ret

    splits = expr.split(delimiter)
    num_splits = len(splits)
    if num_splits == 1:
        # Delimiter not present, this can't possibly be a match
        return False

    # If we have 4 splits, then we have three delimiters. Thus, the indexes we
    # want to use are 3, 2, and 1, in that order.
    for idx in range(num_splits - 1, 0, -1):
        key = delimiter.join(splits[:idx])
        if key == "*":
            # We are matching on everything under the top level, so we need to
            # treat the match as the entire data being passed in
            matchstr = expr
            match = data
        else:
            matchstr = delimiter.join(splits[idx:])
            match = traverse_dict_and_list(data, key, {}, delimiter=delimiter)
        log.debug(
            "Attempting to match '%s' in '%s' using delimiter '%s'",
            matchstr,
            key,
            delimiter,
        )
        if match == {}:
            continue
        if isinstance(match, dict):
            if _dict_match(match, matchstr):
                return True
            continue
        if isinstance(match, (list, tuple)):
            # We are matching a single component to a single list member
            for member in match:
                if isinstance(member, dict):
                    if _dict_match(member, matchstr):
                        return True
                if _match(member, matchstr):
                    return True
            continue
        if _match(match, matchstr):
            return True
    return False
def to_lowercase(data, preserve_dict_class=False):
    """
    Return ``data`` with everything in it recursively converted to lowercase.

    :param data: Value to convert.
    :param preserve_dict_class: Rebuild mappings with their own class instead
        of a plain ``dict``.
    """
    lowered = __change_case(data, "lower", preserve_dict_class)
    return lowered
def to_uppercase(data, preserve_dict_class=False):
    """
    Return ``data`` with everything in it recursively converted to uppercase.

    :param data: Value to convert.
    :param preserve_dict_class: Rebuild mappings with their own class instead
        of a plain ``dict``.
    """
    raised = __change_case(data, "upper", preserve_dict_class)
    return raised
def traverse_dict(
    data: Dict, key: str, default: Any = None, delimiter: str = DEFAULT_TARGET_DELIM
):
    """
    Walk ``data`` along a delimited target string.

    The target is split on ``delimiter`` and each part is used as a
    subscript in turn, so 'foo:bar:baz' (with a colon delimiter) returns
    data['foo']['bar']['baz'] when that value exists. If any step fails,
    the ``default`` argument is returned instead.
    """
    try:
        node = data
        for segment in key.split(delimiter):
            node = node[segment]
    except (KeyError, IndexError, TypeError):
        # Missing key, or a non-indexable value in the middle of traversing
        return default
    else:
        return node
def traverse_dict_and_list(
    data: Dict or List,
    key: Any,
    default: Any = None,
    delimiter: str = DEFAULT_TARGET_DELIM,
):
    """
    Walk a structure of dicts and lists along a delimited target string.

    The target is split on ``delimiter`` and every part is resolved against
    the current position:

    - On a list, a numeric part is used as an index. A non-numeric part is
      looked up in the dicts embedded in that list, and the first dict that
      has it wins.
    - On anything else, the part is used as a subscript. If that raises
      ``KeyError``, the part is passed through ``args.yamlify_arg`` and, when
      that changed its value, the lookup is retried with the result. This is
      how a string-based expression can match a numeric key.

    So 'foo:bar:0' (with a colon delimiter) returns data['foo']['bar'][0] for
    data like {'foo':{'bar':['baz']}} and data['foo']['bar']['0'] for data
    like {'foo':{'bar':{'0':'baz'}}}. Whenever a part cannot be resolved, the
    ``default`` argument is returned.
    """
    node = data
    for part in key.split(delimiter):
        if isinstance(node, list):
            try:
                position = int(part)
            except ValueError:
                # Index was not numeric, lets look at any embedded dicts
                for candidate in node:
                    if not isinstance(candidate, dict):
                        continue
                    try:
                        node = candidate[part]
                    except KeyError:
                        continue
                    break
                else:
                    # No embedded dicts matched, return the default
                    return default
                continue
            try:
                node = node[position]
            except IndexError:
                return default
            continue

        try:
            node = node[part]
        except KeyError:
            # YAML-load the current key (catches integer/float dict keys)
            try:
                converted = args.yamlify_arg(part)
            except Exception:  # pylint: disable=broad-except
                return default
            if converted == part:
                # The key is unchanged after loading, so the KeyError above
                # is a legitimate failure to match the desired key.
                return default
            # Loading changed the key's value, so re-check with the loaded
            # key.
            try:
                node = node[converted]
            except (KeyError, TypeError):
                return default
        except TypeError:
            return default
    return node