first commit

2020-11-03 18:30:14 -08:00
commit 31d8522470
1881 changed files with 345408 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""
Default values, to be imported elsewhere in Salt code
Do NOT import any salt modules (salt.utils, salt.config, etc.) into this file,
as this may result in circular imports.
"""
# Default delimiter for multi-level traversal in targeting
DEFAULT_TARGET_DELIM = ":"
from . import data, differ, mysql, trim, update, xml

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""
This library makes it possible to introspect a dataset and aggregate nodes
when instructed to.
.. note::
    The following examples will be expressed in YAML for convenience's sake:
    - !aggr-scalar will refer to the Scalar python function
    - !aggr-map will refer to the Map python object
    - !aggr-seq will refer to the Sequence python object
How to instruct merging
------------------------
This yaml document has duplicate keys:
.. code-block:: yaml
foo: !aggr-scalar first
foo: !aggr-scalar second
bar: !aggr-map {first: foo}
bar: !aggr-map {second: bar}
baz: !aggr-scalar 42
but the tagged values instruct Salt that overlapping values can be merged
together:
.. code-block:: yaml
foo: !aggr-seq [first, second]
bar: !aggr-map {first: foo, second: bar}
baz: !aggr-seq [42]
The default merge strategy is to keep values untouched
------------------------------------------------------
For example, this yaml document still has duplicate keys, but does not
instruct aggregation:
.. code-block:: yaml
foo: first
foo: second
bar: {first: foo}
bar: {second: bar}
baz: 42
So the values found last prevail:
.. code-block:: yaml
foo: second
bar: {second: bar}
baz: 42
Limitations
-----------
Aggregation is permitted between tagged objects that share the same type.
If not, the default merge strategy prevails.
For example, these declarations:
.. code-block:: yaml
foo: {first: value}
foo: !aggr-map {second: value}
bar: !aggr-map {first: value}
bar: 42
baz: !aggr-seq [42]
baz: [fail]
qux: 42
qux: !aggr-scalar fail
are interpreted like this:
.. code-block:: yaml
foo: !aggr-map {second: value}
bar: 42
baz: [fail]
qux: !aggr-seq [fail]
Introspection
-------------
TODO: write this part
"""
from typing import Iterable, Tuple, Union
import copy
import logging
import collections
__all__ = ["aggregate", "Aggregate", "Map", "Scalar", "Sequence"]
log = logging.getLogger(__name__)
class Aggregate:
"""
Aggregation base.
"""
class Map(collections.OrderedDict, Aggregate):
"""
Map aggregation.
"""
class Sequence(list, Aggregate):
"""
Sequence aggregation.
"""
def Scalar(obj):
"""
Shortcut for Sequence creation
>>> Scalar('foo') == Sequence(['foo'])
True
"""
return Sequence([obj])
def levelise(level: Union[bool, int, Iterable]) -> Tuple[bool, Union[bool, int, Iterable]]:
"""
Describe which levels are allowed to do deep merging.
level can be:
True
all levels are True
False
all levels are False
an int
    only the first ``level`` levels are True, the others are False
a sequence
it describes which levels are True, it can be:
* a list of bool and int values
* a string of 0 and 1 characters
"""
if not level: # False, 0, [] ...
return False, False
if level is True:
return True, True
if isinstance(level, int):
return True, level - 1
try: # a sequence
deep, subs = int(level[0]), level[1:]
return bool(deep), subs
except Exception as error: # pylint: disable=broad-except
log.warning(error)
raise
def mark(
    obj: object, map_class: type = Map, sequence_class: type = Sequence
) -> object:
"""
Convert obj into an Aggregate instance
"""
if isinstance(obj, Aggregate):
return obj
if isinstance(obj, dict):
return map_class(obj)
if isinstance(obj, (list, tuple, set)):
return sequence_class(obj)
else:
return sequence_class([obj])
def aggregate(
obj_a,
obj_b,
    level: Union[bool, int] = False,
    map_class: type = Map,
    sequence_class: type = Sequence,
):
"""
Merge obj_b into obj_a.
>>> aggregate('first', 'second', True) == ['first', 'second']
True
"""
deep, subdeep = levelise(level)
if deep:
obj_a = mark(obj_a, map_class=map_class, sequence_class=sequence_class)
obj_b = mark(obj_b, map_class=map_class, sequence_class=sequence_class)
if isinstance(obj_a, dict) and isinstance(obj_b, dict):
if isinstance(obj_a, Aggregate) and isinstance(obj_b, Aggregate):
            # deep merging is more or less obj_a.update(obj_b)
response = copy.copy(obj_a)
else:
# introspection on obj_b keys only
response = copy.copy(obj_b)
for key, value in obj_b.items():
if key in obj_a:
value = aggregate(obj_a[key], value, subdeep, map_class, sequence_class)
response[key] = value
return response
if isinstance(obj_a, Sequence) and isinstance(obj_b, Sequence):
response = obj_a.__class__(obj_a[:])
for value in obj_b:
if value not in obj_a:
response.append(value)
return response
response = copy.copy(obj_b)
if isinstance(obj_a, Aggregate) or isinstance(obj_b, Aggregate):
log.info("only one value marked as aggregate. keep `obj_b` value")
return response
log.debug("no value marked as aggregate. keep `obj_b` value")
return response
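# Example: with level=True, values under overlapping keys are wrapped in
# Sequences and merged, while keys unique to either side pass through.
# >>> merged = aggregate({"a": 1}, {"a": 2, "b": 3}, level=True)
# >>> dict(merged) == {"a": [1, 2], "b": 3}
# True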

View File

@@ -0,0 +1,89 @@
import yaml
def yamlify_arg(arg):
"""
yaml.safe_load the arg
"""
if not isinstance(arg, str):
return arg
# YAML loads empty (or all whitespace) strings as None:
#
# >>> import yaml
# >>> yaml.load('') is None
# True
# >>> yaml.load(' ') is None
# True
#
# Similarly, YAML document start/end markers would not load properly if
# passed through PyYAML, as loading '---' results in None and '...' raises
# an exception.
#
# Therefore, skip YAML loading for these cases and just return the string
# that was passed in.
if arg.strip() in ("", "---", "..."):
return arg
elif "_" in arg and all([x in "0123456789_" for x in arg.strip()]):
# When the stripped string includes just digits and underscores, the
# underscores are ignored and the digits are combined together and
# loaded as an int. We don't want that, so return the original value.
return arg
else:
if any(np_char in arg for np_char in ("\t", "\r", "\n")):
# Don't mess with this CLI arg, since it has one or more
# non-printable whitespace char. Since the CSafeLoader will
# sanitize these chars rather than raise an exception, just
# skip YAML loading of this argument and keep the argument as
# passed on the CLI.
return arg
try:
original_arg = arg
if "#" in arg:
# Only yamlify if it parses into a non-string type, to prevent
# loss of content due to # as comment character
parsed_arg = yaml.safe_load(arg)
if isinstance(parsed_arg, str) or parsed_arg is None:
return arg
return parsed_arg
if arg == "None":
arg = None
else:
arg = yaml.safe_load(arg)
if isinstance(arg, dict):
# dicts must be wrapped in curly braces
if isinstance(original_arg, str) and not original_arg.startswith("{"):
return original_arg
else:
return arg
elif isinstance(arg, list):
# lists must be wrapped in brackets
if isinstance(original_arg, str) and not original_arg.startswith("["):
return original_arg
else:
return arg
elif arg is None or isinstance(arg, (list, float, int, str)):
# yaml.safe_load will load '|' and '!' as '', don't let it do that.
if arg == "" and original_arg in ("|", "!"):
return original_arg
# yaml.safe_load will treat '#' as a comment, so a value of '#'
# will become None. Keep this value from being stomped as well.
elif arg is None and original_arg.strip().startswith("#"):
return original_arg
# Other times, yaml.safe_load will load '!' as None. Prevent that.
elif arg is None and original_arg == "!":
return original_arg
else:
return arg
else:
# we don't support this type
return original_arg
except Exception: # pylint: disable=broad-except
# In case anything goes wrong...
return original_arg
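# A few illustrative cases of the rules above:
# >>> yamlify_arg("42")
# 42
# >>> yamlify_arg("1_000")    # digits plus underscores stay a string
# '1_000'
# >>> yamlify_arg("{a: 1}")
# {'a': 1}
# >>> yamlify_arg("a: 1")     # dict-like but not brace-wrapped: unchanged
# 'a: 1'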

View File

@@ -0,0 +1,696 @@
import copy
import fnmatch
import logging
import re
import yaml
from . import DEFAULT_TARGET_DELIM
from . import args
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, Dict, Iterable, List, Set, Union
log = logging.getLogger(__name__)
class CaseInsensitiveDict(MutableMapping):
"""
Inspired by requests' case-insensitive dict implementation, but works with
non-string keys as well.
"""
def __init__(self, init=None, **kwargs):
"""
Force internal dict to be ordered to ensure a consistent iteration
order, irrespective of case.
"""
self._data = {}
self.update(init or {}, **kwargs)
def __len__(self):
return len(self._data)
def __setitem__(self, key, value):
# Store the case-sensitive key so it is available for dict iteration
self._data[to_lowercase(key)] = (key, value)
def __delitem__(self, key):
del self._data[to_lowercase(key)]
def __getitem__(self, key):
return self._data[to_lowercase(key)][1]
def __iter__(self):
return (item[0] for item in self._data.values())
def __eq__(self, rval):
if not isinstance(rval, Mapping):
# Comparing to non-mapping type (e.g. int) is always False
return False
return dict(self.items_lower()) == dict(CaseInsensitiveDict(rval).items_lower())
def __repr__(self):
return repr(dict(self.items()))
def items_lower(self):
"""
Returns a generator iterating over keys and values, with the keys all
being lowercase.
"""
return ((key, val[1]) for key, val in self._data.items())
def copy(self):
"""
Returns a copy of the object
"""
        return CaseInsensitiveDict(self._data.values())
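# Example: lookups ignore case, while iteration preserves the case of the
# key as it was originally set.
# >>> cid = CaseInsensitiveDict({"Foo": 1})
# >>> cid["foo"], cid["FOO"]
# (1, 1)
# >>> list(cid)
# ['Foo']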
class ImmutableDict(Mapping):
"""
An abstract base class that implements the interface of a `dict` but is immutable.
Items can be retrieved via namespacing.
No values can be changed after initialization
"""
def __init__(self, init_: Dict[str, Any], **c_kwargs):
"""
:param init_: A dictionary from which to inherit data
"""
init_.update(**c_kwargs)
values = {}
for k, v in init_.items():
if isinstance(v, Dict):
values[k] = ImmutableDict(init_=v)
elif isinstance(v, (tuple, int, str, bytes)):
values[k] = v
elif isinstance(v, Iterable):
values[k] = tuple(v)
else:
values[k] = v
# __setattr__ is borked (on purpose) so we have to call it from super() right here
super().__setattr__("_ImmutableDict__store", values)
def __setattr__(self, k: str, v: Any):
raise TypeError(
f"{self.__class__.__name__} does not support attribute assignment"
)
def __getattr__(self, k: str):
return self.__store[k]
def __getitem__(self, k: str) -> Any:
return self.__store[k]
def __contains__(self, k: str) -> bool:
return k in self.__store
def __iter__(self):
return iter(self.__store)
def __len__(self) -> int:
return len(self.__store.keys())
def __copy__(self) -> Dict[str, Any]:
ret = {}
        # Unpack ImmutableDict items so that it's turtles all the way down
for k, v in self.__store.items():
if isinstance(v, ImmutableDict):
ret[k] = v.__copy__()
else:
ret[k] = v
return ret
def __repr__(self):
return repr(copy.copy(self))
class NamespaceDict(dict):
"""
    A dictionary whose string keys can also be accessed as attributes (via the namespace)
"""
def __init__(self, seq: Iterable = None, **kwargs):
"""
NamespaceDict() -> new empty namespaced dictionary
NamespaceDict(mapping) -> new namespaced dictionary initialized from a mapping object's
(key, value) pairs
NamespaceDict(iterable) -> new namespaced dictionary initialized as if via:
d = {}
for k, v in iterable:
d[k] = v
NamespaceDict(**kwargs) -> new namespaced dictionary initialized with the name=value pairs
in the keyword argument list. For example: NamespaceDict(one=1, two=2)
"""
if seq is None:
super().__init__(**kwargs)
else:
super().__init__(seq, **kwargs)
def __setattr__(self, k: str, v: Any):
if isinstance(v, dict) and not isinstance(v, NamespaceDict):
v = NamespaceDict(v)
self[k] = v
def __getattr__(self, k: str):
if k.startswith("_"):
return super().__getattribute__(k)
return self[k]
def __copy__(self):
return NamespaceDict(self.copy())
def __deepcopy__(self, memodict=None):
if memodict is None:
memodict = {}
return NamespaceDict(copy.deepcopy(self.copy(), memodict))
def __change_case(data, attr, preserve_dict_class=False):
"""
Calls data.attr() if data has an attribute/method called attr.
Processes data recursively if data is a Mapping or Sequence.
For Mapping, processes both keys and values.
"""
try:
return getattr(data, attr)()
except AttributeError:
pass
data_type = data.__class__
if isinstance(data, Mapping):
return (data_type if preserve_dict_class else dict)(
(
__change_case(key, attr, preserve_dict_class),
__change_case(val, attr, preserve_dict_class),
)
for key, val in data.items()
)
if isinstance(data, Sequence):
return data_type(
__change_case(item, attr, preserve_dict_class) for item in data
)
return data
def _remove_circular_refs(ob, _seen: Set = None):
"""
Generic method to remove circular references from objects.
    This has been taken from author Martijn Pieters:
    https://stackoverflow.com/questions/44777369/remove-circular-references-in-dicts-lists-tuples/44777477#44777477
    :param ob: dict, list, tuple, set, or frozenset
        Standard python object
    :param Set _seen:
        Set of ids of objects already visited, used to detect circular references
:returns:
Cleaned Python object
"""
if _seen is None:
_seen = set()
if id(ob) in _seen:
# Here we caught a circular reference.
# Alert user and cleanup to continue.
log.exception(
"Caught a circular reference in data structure below."
"Cleaning and continuing execution.\n%r\n",
ob,
)
return None
_seen.add(id(ob))
res = ob
if isinstance(ob, dict):
res = {
_remove_circular_refs(k, _seen): _remove_circular_refs(v, _seen)
for k, v in ob.items()
}
elif isinstance(ob, (list, tuple, set, frozenset)):
res = type(ob)(_remove_circular_refs(v, _seen) for v in ob)
# remove id again; only *nested* references count
_seen.remove(id(ob))
return res
def compare_dicts(old: Dict = None, new: Dict = None) -> Dict[str, Dict]:
"""
Compare before and after results from various salt functions, returning a
dict describing the changes that were made.
"""
    old = old or {}
    new = new or {}
    ret = {}
    for key in set(new).union(old):
if key not in old:
# New key
ret[key] = {"old": "", "new": new[key]}
elif key not in new:
# Key removed
ret[key] = {"new": "", "old": old[key]}
elif new[key] != old[key]:
# Key modified
ret[key] = {"old": old[key], "new": new[key]}
return ret
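# Example:
# >>> ret = compare_dicts(old={"a": 1, "b": 2}, new={"b": 3, "c": 4})
# >>> ret == {"a": {"old": 1, "new": ""},
# ...         "b": {"old": 2, "new": 3},
# ...         "c": {"old": "", "new": 4}}
# True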
def object_to_dict(obj) -> Dict:
"""
Convert an object to a dictionary
"""
if isinstance(obj, list) or isinstance(obj, tuple):
ret = []
for item in obj:
ret.append(object_to_dict(item))
elif hasattr(obj, "__dict__"):
ret = {}
for item in obj.__dict__:
if item.startswith("_"):
continue
ret[item] = object_to_dict(obj.__dict__[item])
else:
ret = obj
return ret
def is_dictlist(data: List) -> bool:
"""
Returns True if data is a list of one-element dicts (as found in many SLS
schemas), otherwise returns False
"""
if isinstance(data, list):
for element in data:
if isinstance(element, dict):
if len(element) != 1:
return False
else:
return False
return True
return False
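# Example:
# >>> is_dictlist([{"a": 1}, {"b": 2}])
# True
# >>> is_dictlist([{"a": 1, "b": 2}])
# False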
def recursive_diff(
old: Iterable,
new: Iterable,
ignore_keys: List = None,
ignore_order: bool = False,
ignore_missing_keys: bool = False,
) -> Dict[str, Iterable]:
"""
Performs a recursive diff on mappings and/or iterables and returns the result
in a {'old': values, 'new': values}-style.
Compares dicts and sets unordered (obviously), OrderedDicts and Lists ordered
(but only if both ``old`` and ``new`` are of the same type),
all other Mapping types unordered, and all other iterables ordered.
:param old: Mapping or Iterable to compare from.
:param new: Mapping or Iterable to compare to.
:param ignore_keys: List of keys to ignore when comparing Mappings.
:param ignore_order: Compare ordered mapping/iterables as if they were unordered.
:param ignore_missing_keys: Do not return keys only present in ``old``
but missing in ``new``. Only works for regular dicts.
:return dict: Returns dict with keys 'old' and 'new' containing the differences.
"""
ignore_keys = ignore_keys or []
ret_old = copy.deepcopy(old)
ret_new = copy.deepcopy(new)
if isinstance(old, Mapping) and isinstance(new, Mapping) and not ignore_order:
append_old, append_new = [], []
if len(old) != len(new):
min_length = min(len(old), len(new))
# The list coercion is required for Py3
append_old = list(old.keys())[min_length:]
append_new = list(new.keys())[min_length:]
# Compare ordered
for (key_old, key_new) in zip(old, new):
if key_old == key_new:
if key_old in ignore_keys:
del ret_old[key_old]
del ret_new[key_new]
else:
res = recursive_diff(
old[key_old],
new[key_new],
ignore_keys=ignore_keys,
ignore_order=ignore_order,
ignore_missing_keys=ignore_missing_keys,
)
if not res: # Equal
del ret_old[key_old]
del ret_new[key_new]
else:
ret_old[key_old] = res["old"]
ret_new[key_new] = res["new"]
else:
if key_old in ignore_keys:
del ret_old[key_old]
if key_new in ignore_keys:
del ret_new[key_new]
        # If the OrderedDicts were of unequal length, add the remaining key/values.
for item in append_old:
ret_old[item] = old[item]
for item in append_new:
ret_new[item] = new[item]
ret = {"old": ret_old, "new": ret_new} if ret_old or ret_new else {}
elif isinstance(old, Mapping) and isinstance(new, Mapping):
# Compare unordered
for key in set(list(old) + list(new)):
if key in ignore_keys:
del ret_old[key]
del ret_new[key]
elif key in old and key in new:
res = recursive_diff(
old[key],
new[key],
ignore_keys=ignore_keys,
ignore_order=ignore_order,
ignore_missing_keys=ignore_missing_keys,
)
if not res: # Equal
del ret_old[key]
del ret_new[key]
else:
ret_old[key] = res["old"]
ret_new[key] = res["new"]
elif ignore_missing_keys and key in old:
del ret_old[key]
ret = {"old": ret_old, "new": ret_new} if ret_old or ret_new else {}
elif isinstance(old, set) and isinstance(new, set):
ret = {"old": old - new, "new": new - old} if old - new or new - old else {}
elif (
isinstance(old, Iterable)
and not isinstance(old, str)
and isinstance(new, Iterable)
and not isinstance(new, str)
):
# Create a list so we can edit on an index-basis.
list_old = list(ret_old)
list_new = list(ret_new)
if ignore_order:
for item_old in old:
for item_new in new:
res = recursive_diff(
item_old,
item_new,
ignore_keys=ignore_keys,
ignore_order=ignore_order,
ignore_missing_keys=ignore_missing_keys,
)
if not res:
list_old.remove(item_old)
list_new.remove(item_new)
                        # Stop at the first match so the same item is not removed twice
                        break
else:
remove_indices = []
for index, (iter_old, iter_new) in enumerate(zip(old, new)):
res = recursive_diff(
iter_old,
iter_new,
ignore_keys=ignore_keys,
ignore_order=ignore_order,
ignore_missing_keys=ignore_missing_keys,
)
if not res: # Equal
remove_indices.append(index)
else:
list_old[index] = res["old"]
list_new[index] = res["new"]
for index in reversed(remove_indices):
list_old.pop(index)
list_new.pop(index)
# Instantiate a new whatever-it-was using the list as iterable source.
# This may not be the most optimized in way of speed and memory usage,
# but it will work for all iterable types.
ret = (
{"old": type(old)(list_old), "new": type(new)(list_new)}
if list_old or list_new
else {}
)
else:
ret = {} if old == new else {"old": ret_old, "new": ret_new}
return ret
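# Example: only the differing leaves survive in the result.
# >>> recursive_diff({"a": 1, "b": {"c": 2}}, {"a": 1, "b": {"c": 3}})
# {'old': {'b': {'c': 2}}, 'new': {'b': {'c': 3}}}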
def repack_dictlist(data, strict=False, recurse=False, key_cb=None, val_cb=None):
"""
Takes a list of one-element dicts (as found in many SLS schemas) and
repacks into a single dictionary.
"""
if isinstance(data, str):
try:
data = yaml.safe_load(data)
except yaml.parser.ParserError as err:
log.error(err)
return {}
if key_cb is None:
key_cb = lambda x: x
if val_cb is None:
val_cb = lambda x, y: y
valid_non_dict = (str, int, float)
if isinstance(data, list):
for element in data:
if isinstance(element, valid_non_dict):
continue
if isinstance(element, dict):
if len(element) != 1:
log.error(
"Invalid input for repack_dictlist: key/value pairs "
"must contain only one element (data passed: %s).",
element,
)
return {}
else:
log.error(
"Invalid input for repack_dictlist: element %s is "
"not a string/dict/numeric value",
element,
)
return {}
else:
log.error(
"Invalid input for repack_dictlist, data passed is not a list " "(%s)", data
)
return {}
ret = {}
for element in data:
if isinstance(element, valid_non_dict):
ret[key_cb(element)] = None
else:
key = next(iter(element))
val = element[key]
if is_dictlist(val):
if recurse:
ret[key_cb(key)] = repack_dictlist(val, recurse=recurse)
elif strict:
log.error(
"Invalid input for repack_dictlist: nested dictlist "
"found, but recurse is set to False"
)
return {}
else:
ret[key_cb(key)] = val_cb(key, val)
else:
ret[key_cb(key)] = val_cb(key, val)
return ret
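# Example: one-element dicts fold into a single dict; bare string values
# become keys with a None value.
# >>> repack_dictlist([{"a": 1}, {"b": 2}, "c"])
# {'a': 1, 'b': 2, 'c': None}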
def subdict_match(
data: Dict,
expr: str,
delimiter: str = DEFAULT_TARGET_DELIM,
regex_match: bool = False,
exact_match: bool = False,
) -> bool:
"""
Check for a match in a dictionary using a delimiter character to denote
levels of subdicts, and also allowing the delimiter character to be
matched. Thus, 'foo:bar:baz' will match data['foo'] == 'bar:baz' and
data['foo']['bar'] == 'baz'. The latter would take priority over the
former, as more deeply-nested matches are tried first.
"""
def _match(target, pattern):
target = str(target).lower()
pattern = str(pattern).lower()
if regex_match:
try:
return re.match(pattern, target)
except Exception: # pylint: disable=broad-except
log.error("Invalid regex '%s' in match", pattern)
return False
else:
return (
target == pattern if exact_match else fnmatch.fnmatch(target, pattern)
)
def _dict_match(target, pattern):
ret = False
wildcard = pattern.startswith("*:")
if wildcard:
pattern = pattern[2:]
if pattern == "*":
# We are just checking that the key exists
ret = True
if not ret and pattern in target:
# We might want to search for a key
ret = True
if not ret and subdict_match(
target, pattern, regex_match=regex_match, exact_match=exact_match
):
ret = True
if not ret and wildcard:
for key in target:
if isinstance(target[key], dict):
                    if _dict_match(target[key], pattern):
return True
elif isinstance(target[key], list):
for item in target[key]:
                        if _match(item, pattern):
return True
                elif _match(target[key], pattern):
return True
return ret
splits = expr.split(delimiter)
num_splits = len(splits)
if num_splits == 1:
# Delimiter not present, this can't possibly be a match
return False
# If we have 4 splits, then we have three delimiters. Thus, the indexes we
# want to use are 3, 2, and 1, in that order.
for idx in range(num_splits - 1, 0, -1):
key = delimiter.join(splits[:idx])
if key == "*":
# We are matching on everything under the top level, so we need to
# treat the match as the entire data being passed in
matchstr = expr
match = data
else:
matchstr = delimiter.join(splits[idx:])
match = traverse_dict_and_list(data, key, {}, delimiter=delimiter)
log.debug(
"Attempting to match '%s' in '%s' using delimiter '%s'",
matchstr,
key,
delimiter,
)
if match == {}:
continue
if isinstance(match, dict):
if _dict_match(match, matchstr):
return True
continue
if isinstance(match, (list, tuple)):
# We are matching a single component to a single list member
for member in match:
if isinstance(member, dict):
                    if _dict_match(member, matchstr):
return True
if _match(member, matchstr):
return True
continue
if _match(match, matchstr):
return True
return False
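# Example: 'foo:bar:baz' matches data['foo']['bar'] == 'baz' (and would also
# match data['foo'] == 'bar:baz').
# >>> subdict_match({"foo": {"bar": "baz"}}, "foo:bar:baz")
# True
# >>> subdict_match({"foo": {"bar": "baz"}}, "foo:bar:qux")
# False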
def to_lowercase(data, preserve_dict_class=False):
"""
Recursively changes everything in data to lowercase.
"""
return __change_case(data, "lower", preserve_dict_class)
def to_uppercase(data, preserve_dict_class=False):
"""
Recursively changes everything in data to uppercase.
"""
return __change_case(data, "upper", preserve_dict_class)
def traverse_dict(
data: Dict, key: str, default: Any = None, delimiter: str = DEFAULT_TARGET_DELIM
):
"""
Traverse a dict using a colon-delimited (or otherwise delimited, using the
'delimiter' param) target string. The target 'foo:bar:baz' will return
    data['foo']['bar']['baz'] if this value exists, and will otherwise return
    the value of the ``default`` argument.
"""
ptr = data
try:
for each in key.split(delimiter):
ptr = ptr[each]
except (KeyError, IndexError, TypeError):
# Encountered a non-indexable value in the middle of traversing
return default
return ptr
def traverse_dict_and_list(
    data: Union[Dict, List],
key: Any,
default: Any = None,
delimiter: str = DEFAULT_TARGET_DELIM,
):
"""
Traverse a dict or list using a colon-delimited (or otherwise delimited,
using the 'delimiter' param) target string. The target 'foo:bar:0' will
    return data['foo']['bar'][0] if this value exists, and will otherwise
    return the value of the ``default`` argument.
    The function will automatically determine the target type. For example,
    the target 'foo:bar:0' returns data['foo']['bar'][0] when data is
    {'foo': {'bar': ['baz']}}, and data['foo']['bar']['0'] when data is
    {'foo': {'bar': {'0': 'baz'}}}.
"""
ptr = data
for each in key.split(delimiter):
if isinstance(ptr, list):
try:
idx = int(each)
except ValueError:
embed_match = False
                # Index was not numeric, let's look at any embedded dicts
for embedded in (x for x in ptr if isinstance(x, dict)):
try:
ptr = embedded[each]
embed_match = True
break
except KeyError:
pass
if not embed_match:
# No embedded dicts matched, return the default
return default
else:
try:
ptr = ptr[idx]
except IndexError:
return default
else:
try:
ptr = ptr[each]
except KeyError:
# YAML-load the current key (catches integer/float dict keys)
try:
loaded_key = args.yamlify_arg(each)
except Exception: # pylint: disable=broad-except
return default
if loaded_key == each:
# After YAML-loading, the desired key is unchanged. This
# means that the KeyError caught above is a legitimate
# failure to match the desired key. Therefore, return the
# default.
return default
else:
# YAML-loading the key changed its value, so re-check with
# the loaded key. This is how we can match a numeric key
# with a string-based expression.
try:
ptr = ptr[loaded_key]
except (KeyError, TypeError):
return default
except TypeError:
return default
return ptr
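# Example: numeric path components index into lists.
# >>> traverse_dict_and_list({"foo": {"bar": ["baz", "qux"]}}, "foo:bar:1")
# 'qux'
# >>> traverse_dict_and_list({"foo": {"bar": ["baz"]}}, "foo:bar:5", default="oops")
# 'oops'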

View File

@@ -0,0 +1,415 @@
# -*- coding: utf-8 -*-
"""
Calculate the difference between two dictionaries as:
(1) items added
(2) items removed
(3) keys same in both but changed values
(4) keys same in both and unchanged values
Originally posted at http://stackoverflow.com/questions/1165352/fast-comparison-between-two-python-dictionary/1165552#1165552
Available at repository: https://github.com/hughdbrown/dictdiffer
Added the ability to recursively compare dictionaries
"""
import copy
from collections.abc import Mapping
from typing import Any, Dict, List, Text, Set
def diff(current_dict, past_dict):
return DictDiffer(current_dict, past_dict)
class DictDiffer:
"""
Calculate the difference between two dictionaries as:
(1) items added
(2) items removed
(3) keys same in both but changed values
(4) keys same in both and unchanged values
"""
def __init__(self, current_dict: Dict, past_dict: Dict):
self.current_dict, self.past_dict = current_dict, past_dict
        self.set_current, self.set_past = set(current_dict), set(past_dict)
self.intersect = self.set_current.intersection(self.set_past)
def added(self) -> Set:
return self.set_current - self.intersect
def removed(self) -> Set:
return self.set_past - self.intersect
def changed(self) -> Set:
return set(
o for o in self.intersect if self.past_dict[o] != self.current_dict[o]
)
def unchanged(self) -> Set:
return set(
o for o in self.intersect if self.past_dict[o] == self.current_dict[o]
)
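# Example:
# >>> d = diff({"a": 1, "b": 2}, {"b": 3, "c": 4})
# >>> sorted(d.added()), sorted(d.removed()), sorted(d.changed())
# (['a'], ['c'], ['b'])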
class RecursiveDictDiffer(DictDiffer):
"""
Calculates a recursive diff between the current_dict and the past_dict
creating a diff in the format
{'new': new_value, 'old': old_value}
It recursively searches differences in common keys whose values are
dictionaries creating a diff dict in the format
    {'common_key': {'new': new_value, 'old': old_value}}
The class overrides all DictDiffer methods, returning lists of keys and
    subkeys using the . notation (i.e. 'common_key1.common_key2.changed_key')
The class provides access to:
    (1) the added, removed, changed keys and subkeys (using the . notation)
``added``, ``removed``, ``changed`` methods
    (2) the diffs in the format above (diff property)
``diffs`` property
(3) a dict with the new changed values only (new_values property)
``new_values`` property
(4) a dict with the old changed values only (old_values property)
``old_values`` property
(5) a string representation of the changes in the format:
``changes_str`` property
Note:
The <_null_> value is a reserved value
.. code-block:: text
common_key1:
common_key2:
changed_key1 from '<old_str>' to '<new_str>'
changed_key2 from '[<old_elem1>, ..]' to '[<new_elem1>, ..]'
common_key3:
changed_key3 from <old_int> to <new_int>
"""
NONE_VALUE = "<_null_>"
def __init__(self, past_dict: Dict, current_dict: Dict, ignore_missing_keys: bool):
"""
past_dict
Past dictionary.
current_dict
Current dictionary.
ignore_missing_keys
Flag specifying whether to ignore keys that no longer exist in the
current_dict, but exist in the past_dict. If true, the diff will
not contain the missing keys.
"""
super(RecursiveDictDiffer, self).__init__(current_dict, past_dict)
self._diffs = self._get_diffs(
self.current_dict, self.past_dict, ignore_missing_keys
)
        # Ignores unset values when assessing the changes
self.ignore_unset_values = True
@classmethod
def _get_diffs(cls, dict1, dict2, ignore_missing_keys) -> Dict:
"""
Returns a dict with the differences between dict1 and dict2
Notes:
Keys that only exist in dict2 are not included in the diff if
ignore_missing_keys is True, otherwise they are
Simple compares are done on lists
"""
ret_dict = {}
for p in dict1.keys():
if p not in dict2:
ret_dict.update({p: {"new": dict1[p], "old": cls.NONE_VALUE}})
elif dict1[p] != dict2[p]:
if isinstance(dict1[p], dict) and isinstance(dict2[p], dict):
sub_diff_dict = cls._get_diffs(
dict1[p], dict2[p], ignore_missing_keys
)
if sub_diff_dict:
ret_dict.update({p: sub_diff_dict})
else:
ret_dict.update({p: {"new": dict1[p], "old": dict2[p]}})
if not ignore_missing_keys:
for p in dict2.keys():
if p not in dict1.keys():
ret_dict.update({p: {"new": cls.NONE_VALUE, "old": dict2[p]}})
return ret_dict
@classmethod
def _get_values(cls, diff_dict: Dict, type_: str = "new") -> Dict:
"""
        Returns a dictionary with the 'new' or 'old' values from a diff dict.
type_
Which values to return, 'new' or 'old'
"""
ret_dict = {}
for p in diff_dict.keys():
if type_ in diff_dict[p].keys():
ret_dict.update({p: diff_dict[p][type_]})
else:
ret_dict.update({p: cls._get_values(diff_dict[p], type_=type_)})
return ret_dict
@classmethod
def _get_changes(cls, diff_dict: Dict) -> Dict:
"""
        Returns a list of string messages with the differences in a diff dict.
        Each inner difference is indented two spaces deeper.
"""
changes_strings = []
for p in sorted(diff_dict.keys()):
if sorted(diff_dict[p].keys()) == ["new", "old"]:
# Some string formatting
old_value = diff_dict[p]["old"]
if diff_dict[p]["old"] == cls.NONE_VALUE:
old_value = "nothing"
elif isinstance(diff_dict[p]["old"], Text):
old_value = "'{0}'".format(diff_dict[p]["old"])
elif isinstance(diff_dict[p]["old"], list):
old_value = "'{0}'".format(", ".join(diff_dict[p]["old"]))
new_value = diff_dict[p]["new"]
if diff_dict[p]["new"] == cls.NONE_VALUE:
new_value = "nothing"
elif isinstance(diff_dict[p]["new"], Text):
new_value = "'{0}'".format(diff_dict[p]["new"])
elif isinstance(diff_dict[p]["new"], list):
new_value = "'{0}'".format(", ".join(diff_dict[p]["new"]))
changes_strings.append(
"{0} from {1} to {2}".format(p, old_value, new_value)
)
else:
sub_changes = cls._get_changes(diff_dict[p])
if sub_changes:
changes_strings.append("{0}:".format(p))
changes_strings.extend([" {0}".format(c) for c in sub_changes])
return changes_strings
    def added(self) -> List[str]:
"""
Returns all keys that have been added.
If the keys are in child dictionaries they will be represented with
. notation
"""
def _added(diffs, prefix):
keys = []
for key in diffs.keys():
if isinstance(diffs[key], dict) and "old" not in diffs[key]:
keys.extend(
_added(diffs[key], prefix="{0}{1}.".format(prefix, key))
)
elif diffs[key]["old"] == self.NONE_VALUE:
if isinstance(diffs[key]["new"], dict):
keys.extend(
_added(
diffs[key]["new"], prefix="{0}{1}.".format(prefix, key)
)
)
else:
keys.append("{0}{1}".format(prefix, key))
return keys
return sorted(_added(self._diffs, prefix=""))
    def removed(self) -> List[str]:
"""
Returns all keys that have been removed.
If the keys are in child dictionaries they will be represented with
. notation
"""
def _removed(diffs, prefix):
keys = []
for key in diffs.keys():
if isinstance(diffs[key], dict) and "old" not in diffs[key]:
keys.extend(
_removed(diffs[key], prefix="{0}{1}.".format(prefix, key))
)
elif diffs[key]["new"] == self.NONE_VALUE:
keys.append("{0}{1}".format(prefix, key))
elif isinstance(diffs[key]["new"], dict):
keys.extend(
_removed(
diffs[key]["new"], prefix="{0}{1}.".format(prefix, key)
)
)
return keys
return sorted(_removed(self._diffs, prefix=""))
    def changed(self) -> List[str]:
"""
Returns all keys that have been changed.
If the keys are in child dictionaries they will be represented with
. notation
"""
def _changed(diffs, prefix):
keys = []
for key in diffs.keys():
if not isinstance(diffs[key], dict):
continue
if isinstance(diffs[key], dict) and "old" not in diffs[key]:
keys.extend(
_changed(diffs[key], prefix="{0}{1}.".format(prefix, key))
)
continue
if self.ignore_unset_values:
if (
"old" in diffs[key]
and "new" in diffs[key]
and diffs[key]["old"] != self.NONE_VALUE
and diffs[key]["new"] != self.NONE_VALUE
):
if isinstance(diffs[key]["new"], dict):
keys.extend(
_changed(
diffs[key]["new"],
prefix="{0}{1}.".format(prefix, key),
)
)
else:
keys.append("{0}{1}".format(prefix, key))
elif isinstance(diffs[key], dict):
keys.extend(
_changed(diffs[key], prefix="{0}{1}.".format(prefix, key))
)
else:
if "old" in diffs[key] and "new" in diffs[key]:
if isinstance(diffs[key]["new"], dict):
keys.extend(
_changed(
diffs[key]["new"],
prefix="{0}{1}.".format(prefix, key),
)
)
else:
keys.append("{0}{1}".format(prefix, key))
elif isinstance(diffs[key], dict):
keys.extend(
_changed(diffs[key], prefix="{0}{1}.".format(prefix, key))
)
return keys
return sorted(_changed(self._diffs, prefix=""))
    def unchanged(self) -> List[str]:
"""
Returns all keys that have been unchanged.
If the keys are in child dictionaries they will be represented with
. notation
"""
def _unchanged(current_dict, diffs, prefix):
keys = []
for key in current_dict.keys():
if key not in diffs:
keys.append("{0}{1}".format(prefix, key))
elif isinstance(current_dict[key], dict):
if "new" in diffs[key]:
# There is a diff
continue
else:
keys.extend(
_unchanged(
current_dict[key],
diffs[key],
prefix="{0}{1}.".format(prefix, key),
)
)
return keys
return sorted(_unchanged(self.current_dict, self._diffs, prefix=""))
@property
def diffs(self) -> Dict:
"""Returns a dict with the recursive diffs current_dict - past_dict"""
return self._diffs
@property
def new_values(self) -> Dict:
"""Returns a dictionary with the new values"""
return self._get_values(self._diffs, type_="new")
@property
def old_values(self) -> Dict:
"""Returns a dictionary with the old values"""
return self._get_values(self._diffs, type_="old")
@property
def changes_str(self) -> str:
"""Returns a string describing the changes"""
return "\n".join(self._get_changes(self._diffs))
def deep_diff(old: Dict, new: Dict, ignore: List = None) -> Dict[str, Any]:
ignore = ignore or []
res = {}
old = copy.deepcopy(old) or {}
new = copy.deepcopy(new) or {}
stack = [(old, new, False)]
while len(stack) > 0:
tmps = []
tmp_old, tmp_new, reentrant = stack.pop()
for key in set(list(tmp_old) + list(tmp_new)):
if key in tmp_old and key in tmp_new and tmp_old[key] == tmp_new[key]:
del tmp_old[key]
del tmp_new[key]
continue
if not reentrant:
if key in tmp_old and key in ignore:
del tmp_old[key]
if key in tmp_new and key in ignore:
del tmp_new[key]
if isinstance(tmp_old.get(key), Mapping) and isinstance(
tmp_new.get(key), Mapping
):
tmps.append((tmp_old[key], tmp_new[key], False))
if tmps:
stack.extend([(tmp_old, tmp_new, True)] + tmps)
if old:
res["old"] = old
if new:
res["new"] = new
return res
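# Example: equal keys are pruned in place, leaving only the differences.
# >>> deep_diff({"a": 1, "b": {"c": 2}}, {"a": 1, "b": {"c": 3}})
# {'old': {'b': {'c': 2}}, 'new': {'b': {'c': 3}}}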
def recursive_diff(
past_dict: Dict, current_dict: Dict, ignore_missing_keys: bool = True
) -> RecursiveDictDiffer:
"""
Returns a RecursiveDictDiffer object that computes the recursive diffs
between two dictionaries
past_dict
Past dictionary
current_dict
Current dictionary
ignore_missing_keys
Flag specifying whether to ignore keys that no longer exist in the
current_dict, but exist in the past_dict. If true, the diff will
not contain the missing keys.
Default is True.
"""
return RecursiveDictDiffer(past_dict, current_dict, ignore_missing_keys)

View File

@@ -0,0 +1,44 @@
from typing import Dict, List, Union
def to_num(text: str) -> Union[int, float, str]:
"""
Convert a string to a number.
Returns an integer if the string represents an integer, a floating
point number if the string is a real number, or the string unchanged
otherwise.
"""
try:
return int(text)
except ValueError:
try:
return float(text)
except ValueError:
return text
def to_dict(data: List[str], key: str) -> Dict:
"""
Convert MySQL-style output to a python dictionary
"""
ret = {}
headers = [""]
for line in data:
if not line:
continue
if line.startswith("+"):
continue
comps = line.split("|")
for comp in range(len(comps)):
comps[comp] = comps[comp].strip()
if len(headers) > 1:
index = len(headers) - 1
row = {}
for field in range(index):
if field < 1:
continue
row[headers[field]] = to_num(comps[field])
ret[row[key]] = row
else:
headers = comps
return ret
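# Example with the MySQL-style table output this function expects:
# >>> lines = [
# ...     "+------+------+",
# ...     "| name | size |",
# ...     "+------+------+",
# ...     "| foo  | 5    |",
# ...     "+------+------+",
# ... ]
# >>> to_dict(lines, "name")
# {'foo': {'name': 'foo', 'size': 5}}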

View File

@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, unicode_literals
from typing import Any, Dict
import msgpack
import sys
def _trim_dict_in_dict(data: Dict, max_val_size: int, replace_with: Any):
"""
Takes a dictionary, max_val_size and replace_with
and recursively loops through and replaces any values
that are greater than max_val_size.
"""
for key in data:
if isinstance(data[key], dict):
_trim_dict_in_dict(data[key], max_val_size, replace_with)
else:
if sys.getsizeof(data[key]) > max_val_size:
data[key] = replace_with
def trim_dict(
data: Any,
max_dict_bytes: int,
percent: float = 50.0,
stepper_size: int = 10,
replace_with: str = "VALUE_TRIMMED",
is_msgpacked: bool = False,
use_bin_type: bool = False,
):
"""
Takes a dictionary and iterates over its keys, looking for
large values and replacing them with a trimmed string.
    If, after the first pass over the dictionary keys, the dictionary is not
    sufficiently small, the allowed value size is lowered by stepper_size
    percent and the dictionary is rescanned. This allows for progressive
    scanning, removing large items first and only making additional passes
    for smaller items if necessary.
This function uses msgpack to calculate the size of the dictionary
in question. While this might seem like unnecessary overhead, a
data structure in python must be serialized in order for sys.getsizeof()
to accurately return the items referenced in the structure.
Ex:
>>> dict_tools.trim.trim_dict({'a': 'b', 'c': 'x' * 10000}, 100)
{'a': 'b', 'c': 'VALUE_TRIMMED'}
    To improve performance, it is advisable to pass in msgpacked
    data structures instead of raw dictionaries. If a msgpack
    structure is passed in, it will not be deserialized unless
    necessary.
If a msgpack is passed in, it will be repacked if necessary
before being returned.
:param use_bin_type: Set this to true if "is_msgpacked=True"
and the msgpack data has been encoded
with "use_bin_type=True". This also means
that the msgpack data should be decoded with
"encoding='utf-8'".
"""
if is_msgpacked:
dict_size = sys.getsizeof(data)
else:
dict_size = sys.getsizeof(msgpack.dumps(data))
if dict_size > max_dict_bytes:
if is_msgpacked:
if use_bin_type:
data = msgpack.loads(data, encoding="utf-8")
else:
data = msgpack.loads(data)
while True:
percent = float(percent)
max_val_size = float(max_dict_bytes * (percent / 100))
try:
for key in data:
if isinstance(data[key], dict):
_trim_dict_in_dict(data[key], max_val_size, replace_with)
else:
if sys.getsizeof(data[key]) > max_val_size:
data[key] = replace_with
percent = percent - stepper_size
max_val_size = float(max_dict_bytes * (percent / 100))
if use_bin_type:
dump_data = msgpack.dumps(data, use_bin_type=True)
else:
dump_data = msgpack.dumps(data)
cur_dict_size = sys.getsizeof(dump_data)
if cur_dict_size < max_dict_bytes:
if is_msgpacked: # Repack it
return dump_data
else:
return data
elif max_val_size == 0:
if is_msgpacked:
return dump_data
else:
return data
except ValueError:
pass
if is_msgpacked:
if use_bin_type:
return msgpack.dumps(data, use_bin_type=True)
else:
return msgpack.dumps(data)
else:
return data

View File

@@ -0,0 +1,269 @@
# -*- coding: utf-8 -*-
"""
Alex Martelli's solution for recursive dict update from
http://stackoverflow.com/a/3233356
"""
import copy
import logging
from collections.abc import Mapping
from typing import Any, Dict, Tuple

from . import data, yamlex, DEFAULT_TARGET_DELIM
log = logging.getLogger(__name__)
def update(
dest: Dict, upd: Dict, recursive_update: bool = True, merge_lists: bool = False
) -> Dict:
"""
Recursive version of the default dict.update
Merges upd recursively into dest
If recursive_update=False, will use the classic dict.update, or fall back
on a manual merge (helpful for non-dict types like FunctionWrapper)
If merge_lists=True, will aggregate list object types instead of replace.
The list in ``upd`` is added to the list in ``dest``, so the resulting list
is ``dest[key] + upd[key]``. This behavior is only activated when
recursive_update=True. By default merge_lists=False.
    .. versionchanged:: 2016.11.6
When merging lists, duplicate values are removed. Values already
present in the ``dest`` list are not added from the ``upd`` list.
"""
if (not isinstance(dest, Mapping)) or (not isinstance(upd, Mapping)):
raise TypeError("Cannot update using non-dict types in dictupdate.update()")
updkeys = list(upd.keys())
if not set(list(dest.keys())) & set(updkeys):
recursive_update = False
if recursive_update:
for key in updkeys:
val = upd[key]
try:
dest_subkey = dest.get(key, None)
except AttributeError:
dest_subkey = None
if isinstance(dest_subkey, Mapping) and isinstance(val, Mapping):
ret = update(dest_subkey, val, merge_lists=merge_lists)
dest[key] = ret
elif isinstance(dest_subkey, list) and isinstance(val, list):
if merge_lists:
merged = copy.deepcopy(dest_subkey)
merged.extend([x for x in val if x not in merged])
dest[key] = merged
else:
dest[key] = upd[key]
else:
dest[key] = upd[key]
return dest
    # Classic update for the non-recursive case
    for k in upd:
        dest[k] = upd[k]
return dest
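# Example: nested dicts merge recursively; with merge_lists=True, list values
# are concatenated with duplicates dropped.
# >>> update({"a": {"b": 1}, "l": [1]}, {"a": {"c": 2}, "l": [1, 2]}, merge_lists=True)
# {'a': {'b': 1, 'c': 2}, 'l': [1, 2]}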
def merge_list(obj_a: Dict, obj_b: Dict) -> Dict:
ret = {}
for key, val in obj_a.items():
if key in obj_b:
ret[key] = [val, obj_b[key]]
else:
ret[key] = val
return ret
def merge_recurse(obj_a: Dict, obj_b: Dict, merge_lists: bool = False) -> Dict:
copied = copy.deepcopy(obj_a)
return update(copied, obj_b, merge_lists=merge_lists)
def merge_aggregate(obj_a, obj_b):
return yamlex.merge_recursive(obj_a, obj_b, level=1)
def merge_overwrite(obj_a, obj_b, merge_lists: bool = False):
for obj in obj_b:
if obj in obj_a:
obj_a[obj] = obj_b[obj]
return merge_recurse(obj_a, obj_b, merge_lists=merge_lists)
def merge(
obj_a,
obj_b,
strategy: str = "smart",
renderer: str = "yaml",
merge_lists: bool = False,
):
if strategy == "smart":
if renderer.split("|")[-1] == "yamlex" or renderer.startswith("yamlex_"):
strategy = "aggregate"
else:
strategy = "recurse"
if strategy == "list":
merged = merge_list(obj_a, obj_b)
elif strategy == "recurse":
merged = merge_recurse(obj_a, obj_b, merge_lists)
elif strategy == "aggregate":
        # level=1 merges at least the root data
merged = merge_aggregate(obj_a, obj_b)
elif strategy == "overwrite":
merged = merge_overwrite(obj_a, obj_b, merge_lists)
elif strategy == "none":
        # If we do not want to merge, only one pillar was passed, so we can
        # safely use the default recurse; we just do not want to log an error
merged = merge_recurse(obj_a, obj_b)
else:
log.warning("Unknown merging strategy '%s', fallback to recurse", strategy)
merged = merge_recurse(obj_a, obj_b)
return merged
def ensure_dict_key(
in_dict: Dict, keys: str, delimiter: str = DEFAULT_TARGET_DELIM
) -> Dict:
"""
Ensures that in_dict contains the series of recursive keys defined in keys.
:param dict in_dict: The dict to work with.
:param str keys: The delimited string with one or more keys.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: dict
:return: Returns the modified in-place `in_dict`.
"""
if delimiter in keys:
a_keys = keys.split(delimiter)
else:
a_keys = [keys]
dict_pointer = in_dict
while a_keys:
current_key = a_keys.pop(0)
if current_key not in dict_pointer or not isinstance(
dict_pointer[current_key], dict
):
dict_pointer[current_key] = {}
dict_pointer = dict_pointer[current_key]
return in_dict
def _dict_rpartition(
in_dict: Dict, keys: str, delimiter: str = DEFAULT_TARGET_DELIM
) -> Tuple[Dict, str]:
"""
Helper function to:
- Ensure all but the last key in `keys` exist recursively in `in_dict`.
    - Return the dict at the next-to-last key, and the last key
:param dict in_dict: The dict to work with.
:param str keys: The delimited string with one or more keys.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: tuple(dict, str)
    :return: (The dict at the next-to-last key, the last key)
"""
if delimiter in keys:
all_but_last_keys, _, last_key = keys.rpartition(delimiter)
ensure_dict_key(in_dict, all_but_last_keys, delimiter=delimiter)
dict_pointer = data.traverse_dict(
in_dict, all_but_last_keys, default=None, delimiter=delimiter
)
else:
dict_pointer = in_dict
last_key = keys
return dict_pointer, last_key
def set_dict_key_value(
in_dict: Dict, keys: str, value: Any, delimiter: str = DEFAULT_TARGET_DELIM
) -> Dict:
"""
Ensures that in_dict contains the series of recursive keys defined in keys.
Also sets whatever is at the end of `in_dict` traversed with `keys` to `value`.
:param dict in_dict: The dictionary to work with
:param str keys: The delimited string with one or more keys.
:param any value: The value to assign to the nested dict-key.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: dict
:return: Returns the modified in-place `in_dict`.
"""
dict_pointer, last_key = _dict_rpartition(in_dict, keys, delimiter=delimiter)
dict_pointer[last_key] = value
return in_dict
def update_dict_key_value(
in_dict: Dict, keys: str, value: Any, delimiter: str = DEFAULT_TARGET_DELIM
) -> Dict:
"""
Ensures that in_dict contains the series of recursive keys defined in keys.
Also updates the dict, that is at the end of `in_dict` traversed with `keys`,
with `value`.
:param dict in_dict: The dictionary to work with
:param str keys: The delimited string with one or more keys.
:param any value: The value to update the nested dict-key with.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: dict
:return: Returns the modified in-place `in_dict`.
"""
dict_pointer, last_key = _dict_rpartition(in_dict, keys, delimiter=delimiter)
if last_key not in dict_pointer or dict_pointer[last_key] is None:
dict_pointer[last_key] = {}
dict_pointer[last_key].update(value)
return in_dict
def append_dict_key_value(
in_dict: Dict, keys: str, value: Any, delimiter: str = DEFAULT_TARGET_DELIM
) -> Dict:
"""
Ensures that in_dict contains the series of recursive keys defined in keys.
Also appends `value` to the list that is at the end of `in_dict` traversed
with `keys`.
:param dict in_dict: The dictionary to work with
:param str keys: The delimited string with one or more keys.
:param any value: The value to append to the nested dict-key.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: dict
:return: Returns the modified in-place `in_dict`.
"""
dict_pointer, last_key = _dict_rpartition(in_dict, keys, delimiter=delimiter)
if last_key not in dict_pointer or dict_pointer[last_key] is None:
dict_pointer[last_key] = []
dict_pointer[last_key].append(value)
return in_dict
def extend_dict_key_value(
in_dict: Dict, keys: str, value: Any, delimiter: str = DEFAULT_TARGET_DELIM
) -> Dict:
"""
Ensures that in_dict contains the series of recursive keys defined in keys.
Also extends the list, that is at the end of `in_dict` traversed with `keys`,
with `value`.
:param dict in_dict: The dictionary to work with
:param str keys: The delimited string with one or more keys.
:param any value: The value to extend the nested dict-key with.
:param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
:rtype: dict
:return: Returns the modified in-place `in_dict`.
"""
dict_pointer, last_key = _dict_rpartition(in_dict, keys, delimiter=delimiter)
if last_key not in dict_pointer or dict_pointer[last_key] is None:
dict_pointer[last_key] = []
dict_pointer[last_key].extend(value)
return in_dict
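# Example: the *_dict_key_value helpers create intermediate keys as needed.
# >>> d = {}
# >>> set_dict_key_value(d, "foo:bar", 42)
# {'foo': {'bar': 42}}
# >>> append_dict_key_value(d, "foo:baz", 1)
# {'foo': {'bar': 42, 'baz': [1]}}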

View File

@@ -0,0 +1,22 @@
from typing import ByteString, Dict, Iterable, Mapping, Text
def decode_dict(data: Dict[bytes, bytes]) -> Dict[str, str]:
"""
Recursively decode all byte-strings found in a dictionary
"""
ret = {}
for key, value in data.items():
if isinstance(key, ByteString):
key = key.decode()
        if isinstance(value, Mapping):
ret[key] = decode_dict(value)
elif isinstance(value, ByteString):
ret[key] = value.decode()
        elif isinstance(value, Iterable) and not isinstance(value, Text):
            # Rebuild the same container type with its members decoded
            ret[key] = type(value)(
                x.decode() if isinstance(x, ByteString) else x for x in value
            )
else:
ret[key] = value
return ret
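# Example:
# >>> decode_dict({b"key": b"value", "nested": {b"a": 1}})
# {'key': 'value', 'nested': {'a': 1}}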

View File

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
"""
Various XML utilities
"""
from typing import Dict
from xml.etree import ElementTree
def _conv_name(x: str) -> str:
"""
    If this XML tree has an xmlns attribute, then etree will add it to the
    beginning of the tag, like: "{http://path}tag". Strip that namespace
    prefix, if present, and return the bare tag name.
"""
if "}" in x:
comps = x.split("}")
name = comps[1]
return name
return x
def _to_dict(xmltree: ElementTree.Element) -> Dict:
"""
    Converts an XML element tree to a dictionary that only contains child
    items (attributes are ignored). This was the default behavior in version
    2017.7 and remains the default to prevent unexpected parsing issues in
    modules that depend on it.
"""
# If this object has no children, the for..loop below will return nothing
# for it, so just return a single dict representing it.
if not xmltree:
name = _conv_name(xmltree.tag)
return {name: xmltree.text}
xmldict = {}
for item in xmltree:
name = _conv_name(item.tag)
if name not in xmldict:
if item:
xmldict[name] = _to_dict(item)
else:
xmldict[name] = item.text
else:
# If a tag appears more than once in the same place, convert it to
# a list. This may require that the caller watch for such a thing
# to happen, and behave accordingly.
if not isinstance(xmldict[name], list):
xmldict[name] = [xmldict[name]]
xmldict[name].append(_to_dict(item))
return xmldict
def _to_full_dict(xmltree: ElementTree.Element):
"""
Returns the full XML dictionary including attributes.
"""
xmldict = {}
for attrName, attrValue in xmltree.attrib.items():
xmldict[attrName] = attrValue
if not xmltree:
if not xmldict:
# If we don't have attributes, we should return the value as a string
# ex: <entry>test</entry>
return xmltree.text
elif xmltree.text:
            # XML allows for empty sets with attributes, so make sure we capture this.
# ex: <entry name="test"/>
xmldict[_conv_name(xmltree.tag)] = xmltree.text
for item in xmltree:
name = _conv_name(item.tag)
if name not in xmldict:
xmldict[name] = _to_full_dict(item)
else:
# If a tag appears more than once in the same place, convert it to
# a list. This may require that the caller watch for such a thing
# to happen, and behave accordingly.
if not isinstance(xmldict[name], list):
xmldict[name] = [xmldict[name]]
xmldict[name].append(_to_full_dict(item))
return xmldict
def to_dict(xmltree: ElementTree.Element, attr: bool = False):
"""
Convert an XML tree into a dict. The tree that is passed in must be an
ElementTree object.
Args:
xmltree: An ElementTree object.
attr: If true, attributes will be parsed. If false, they will be ignored.
"""
if attr:
return _to_full_dict(xmltree)
else:
return _to_dict(xmltree)
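# Example:
# >>> tree = ElementTree.fromstring("<root><a>1</a><b x='2'>3</b></root>")
# >>> to_dict(tree)
# {'a': '1', 'b': '3'}
# >>> to_dict(tree, attr=True)
# {'a': '1', 'b': {'x': '2', 'b': '3'}}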

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""
YAMLEX is a format that makes things like sls files more intuitive.
It's an extension of YAML that implements all the salt magic:
- it implies omap for anything dict-like.
- it implies that string-like data are str, not unicode
- ...
For example, the file `states.sls` has these contents:
.. code-block:: yaml
foo:
bar: 42
baz: [1, 2, 3]
The file can be parsed into Python like this:
.. code-block:: python
from salt.serializers import yamlex
    with open('states.sls', 'r') as stream:
obj = yamlex.deserialize(stream)
Check that ``obj`` is an OrderedDict
.. code-block:: python
from salt.utils.odict import OrderedDict
assert isinstance(obj, dict)
assert isinstance(obj, OrderedDict)
The `__repr__` and `__str__` methods of yamlex objects render YAML-parseable
strings, which makes them template friendly.
.. code-block:: python
    print('{0}'.format(obj))
returns:
::
{foo: {bar: 42, baz: [1, 2, 3]}}
and they are still valid YAML:
.. code-block:: python
from salt.serializers import yaml
yml_obj = yaml.deserialize(str(obj))
assert yml_obj == obj
yamlex also implements custom tags:
!aggregate
    this tag allows structure aggregation.
For example:
.. code-block:: yaml
placeholder: !aggregate foo
placeholder: !aggregate bar
placeholder: !aggregate baz
is rendered as
.. code-block:: yaml
placeholder: [foo, bar, baz]
!reset
    this tag flushes the previously accumulated value.
.. code-block:: yaml
placeholder: {!aggregate foo: {foo: 42}}
placeholder: {!aggregate foo: {bar: null}}
!reset placeholder: {!aggregate foo: {baz: inga}}
is roughly equivalent to
.. code-block:: yaml
placeholder: {!aggregate foo: {baz: inga}}
The document itself is de facto an aggregate mapping.
"""
# pylint: disable=invalid-name,no-member,missing-docstring,no-self-use
# pylint: disable=too-few-public-methods,too-many-public-methods
# Import python libs
import collections
import copy
import datetime
import logging
from typing import TextIO, Union
# Import 3rd-party libs
import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode
from .aggregation import Map, Sequence, aggregate
__all__ = ["deserialize", "serialize", "available"]
log = logging.getLogger(__name__)
available = True
# prefer C bindings over python when available
BaseLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
# CSafeDumper causes repr errors in python3, so use the pure Python one
try:
# Depending on how PyYAML was built, yaml.SafeDumper may actually be
# yaml.cyaml.CSafeDumper (i.e. the C dumper instead of pure Python).
BaseDumper = yaml.dumper.SafeDumper
except AttributeError:
# Here just in case, but yaml.dumper.SafeDumper should always exist
BaseDumper = yaml.SafeDumper
ERROR_MAP = {
    "found character '\\t' that cannot start any token": "Illegal tab character"
}
def deserialize(stream_or_string: Union[str, TextIO], **options):
"""
    Deserialize any string or stream-like object into a Python data structure.
    :param stream_or_string: stream or string to deserialize.
    :param options: options passed to the underlying yaml module.
"""
options.setdefault("Loader", Loader)
return yaml.load(stream_or_string, **options)
def serialize(obj, **options):
"""
Serialize Python data to YAML.
:param obj: the data structure to serialize
    :param options: options passed to the underlying yaml module.
"""
options.setdefault("Dumper", Dumper)
options.setdefault("default_flow_style", None)
response = yaml.dump(obj, **options)
if response.endswith("\n...\n"):
return response[:-5]
if response.endswith("\n"):
return response[:-1]
return response
class Loader(BaseLoader): # pylint: disable=W0232
"""
Create a custom YAML loader that uses the custom constructor. This allows
for the YAML loading defaults to be manipulated based on needs within salt
to make things like sls file more intuitive.
"""
DEFAULT_SCALAR_TAG = "tag:yaml.org,2002:str"
DEFAULT_SEQUENCE_TAG = "tag:yaml.org,2002:seq"
DEFAULT_MAPPING_TAG = "tag:yaml.org,2002:omap"
def compose_document(self):
node = BaseLoader.compose_document(self)
node.tag = "!aggregate"
return node
def construct_yaml_omap(self, node):
"""
Build the SLSMap
"""
sls_map = SLSMap()
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found {0}".format(node.id),
node.start_mark,
)
self.flatten_mapping(node)
for key_node, value_node in node.value:
            # The !reset instruction applies to the document only.
            # It resets the previously decoded value for this key.
            reset = key_node.tag == "!reset"
            # The !aggregate tag applies only to values, not keys, so be
            # strict about it: warn and move the tag onto the value.
if key_node.tag == "!aggregate":
log.warning("!aggregate applies on values only, not on keys")
value_node.tag = key_node.tag
key_node.tag = self.resolve_sls_tag(key_node)[0]
key = self.construct_object(key_node, deep=False)
try:
hash(key)
except TypeError:
err = (
"While constructing a mapping {0} found unacceptable " "key {1}"
).format(node.start_mark, key_node.start_mark)
raise ConstructorError(err)
value = self.construct_object(value_node, deep=False)
if key in sls_map and not reset:
value = merge_recursive(sls_map[key], value)
sls_map[key] = value
return sls_map
def construct_sls_str(self, node):
"""
Build the SLSString.
"""
# Ensure obj is str, not py2 unicode or py3 bytes
obj = self.construct_scalar(node)
return SLSString(obj)
def construct_sls_int(self, node):
"""
        Verify integers and pass them in correctly if they are declared
        as octal
"""
if node.value == "0":
pass
elif node.value.startswith("0") and not node.value.startswith(("0b", "0x")):
node.value = node.value.lstrip("0")
# If value was all zeros, node.value would have been reduced to
# an empty string. Change it to '0'.
if node.value == "":
node.value = "0"
return int(node.value)
def construct_sls_aggregate(self, node):
try:
tag, deep = self.resolve_sls_tag(node)
except Exception: # pylint: disable=broad-except
raise ConstructorError("unable to build reset")
node = copy.copy(node)
node.tag = tag
obj = self.construct_object(node, deep)
if obj is None:
return AggregatedSequence()
elif tag == self.DEFAULT_MAPPING_TAG:
return AggregatedMap(obj)
elif tag == self.DEFAULT_SEQUENCE_TAG:
return AggregatedSequence(obj)
return AggregatedSequence([obj])
def construct_sls_reset(self, node):
try:
tag, deep = self.resolve_sls_tag(node)
except Exception: # pylint: disable=broad-except
raise ConstructorError("unable to build reset")
node = copy.copy(node)
node.tag = tag
return self.construct_object(node, deep)
def resolve_sls_tag(self, node):
if isinstance(node, yaml.nodes.ScalarNode):
# search implicit tag
tag = self.resolve(yaml.nodes.ScalarNode, node.value, [True, True])
deep = False
elif isinstance(node, yaml.nodes.SequenceNode):
tag = self.DEFAULT_SEQUENCE_TAG
deep = True
elif isinstance(node, yaml.nodes.MappingNode):
tag = self.DEFAULT_MAPPING_TAG
deep = True
else:
raise ConstructorError("unable to resolve tag")
return tag, deep
Loader.add_constructor("!aggregate", Loader.construct_sls_aggregate) # custom type
Loader.add_constructor("!reset", Loader.construct_sls_reset) # custom type
Loader.add_constructor(
"tag:yaml.org,2002:omap", Loader.construct_yaml_omap
) # our overwrite
Loader.add_constructor(
"tag:yaml.org,2002:str", Loader.construct_sls_str
) # our overwrite
Loader.add_constructor(
"tag:yaml.org,2002:int", Loader.construct_sls_int
) # our overwrite
Loader.add_multi_constructor("tag:yaml.org,2002:null", Loader.construct_yaml_null)
Loader.add_multi_constructor("tag:yaml.org,2002:bool", Loader.construct_yaml_bool)
Loader.add_multi_constructor("tag:yaml.org,2002:float", Loader.construct_yaml_float)
Loader.add_multi_constructor("tag:yaml.org,2002:binary", Loader.construct_yaml_binary)
Loader.add_multi_constructor(
"tag:yaml.org,2002:timestamp", Loader.construct_yaml_timestamp
)
Loader.add_multi_constructor("tag:yaml.org,2002:pairs", Loader.construct_yaml_pairs)
Loader.add_multi_constructor("tag:yaml.org,2002:set", Loader.construct_yaml_set)
Loader.add_multi_constructor("tag:yaml.org,2002:seq", Loader.construct_yaml_seq)
Loader.add_multi_constructor("tag:yaml.org,2002:map", Loader.construct_yaml_map)
class SLSMap(collections.OrderedDict):
"""
Ensures that dict str() and repr() are YAML friendly.
.. code-block:: python
        >>> mapping = OrderedDict([('a', 'b'), ('c', None)])
        >>> print(mapping)
        OrderedDict([('a', 'b'), ('c', None)])
        >>> sls_map = SLSMap(mapping)
        >>> print(sls_map)
        {a: b, c: null}
"""
def __str__(self):
return serialize(self, default_flow_style=True)
def __repr__(self, _repr_running=None):
return serialize(self, default_flow_style=True)
class SLSString(str):
"""
Ensures that str str() and repr() are YAML friendly.
.. code-block:: python
        >>> scalar = str('foo')
        >>> print(scalar)
        foo
        >>> sls_scalar = SLSString(scalar)
        >>> print(sls_scalar)
        "foo"
"""
def __str__(self):
return serialize(self, default_style='"')
def __repr__(self):
return serialize(self, default_style='"')
class AggregatedMap(SLSMap, Map):
pass
class AggregatedSequence(Sequence):
pass
class Dumper(BaseDumper): # pylint: disable=W0232
"""
sls dumper.
"""
def represent_odict(self, data):
return self.represent_mapping("tag:yaml.org,2002:map", list(data.items()))
Dumper.add_multi_representer(type(None), Dumper.represent_none)
Dumper.add_multi_representer(bytes, Dumper.represent_binary)
Dumper.add_multi_representer(str, Dumper.represent_str)
Dumper.add_multi_representer(bool, Dumper.represent_bool)
Dumper.add_multi_representer(int, Dumper.represent_int)
Dumper.add_multi_representer(float, Dumper.represent_float)
Dumper.add_multi_representer(list, Dumper.represent_list)
Dumper.add_multi_representer(tuple, Dumper.represent_list)
Dumper.add_multi_representer(
dict, Dumper.represent_odict
) # make every dict like obj to be represented as a map
Dumper.add_multi_representer(set, Dumper.represent_set)
Dumper.add_multi_representer(datetime.date, Dumper.represent_date)
Dumper.add_multi_representer(datetime.datetime, Dumper.represent_datetime)
Dumper.add_multi_representer(None, Dumper.represent_undefined)
def merge_recursive(obj_a, obj_b, level: Union[bool, int] = False):
"""
Merge obj_b into obj_a.
"""
return aggregate(
obj_a, obj_b, level, map_class=AggregatedMap, sequence_class=AggregatedSequence
)
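# Example: duplicate keys tagged !aggregate are merged by the custom Loader.
# >>> obj = deserialize("placeholder: !aggregate foo\nplaceholder: !aggregate bar\n")
# >>> obj == {"placeholder": ["foo", "bar"]}
# True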