diff --git a/.travis.yml b/.travis.yml index 0814116..ccbdc43 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ python: install: - pip install coverage - pip install --upgrade pytest pytest-benchmark + - pip install pytypes script: - | diff --git a/multipledispatch/conflict.py b/multipledispatch/conflict.py index 8120a0d..0416391 100644 --- a/multipledispatch/conflict.py +++ b/multipledispatch/conflict.py @@ -1,18 +1,28 @@ from .utils import _toposort, groupby +from pytypes import is_subtype, is_Union, get_Union_params + class AmbiguityWarning(Warning): pass +def safe_subtype(a, b): + """Union safe subclass""" + if is_Union(a): + return any(is_subtype(tp, b) for tp in get_Union_params(a)) + else: + return is_subtype(a, b) + + def supercedes(a, b): """ A is consistent and strictly more specific than B """ - return len(a) == len(b) and all(map(issubclass, a, b)) + return len(a) == len(b) and all(map(safe_subtype, a, b)) def consistent(a, b): """ It is possible for an argument list to satisfy both A and B """ return (len(a) == len(b) and - all(issubclass(aa, bb) or issubclass(bb, aa) + all(safe_subtype(aa, bb) or safe_subtype(bb, aa) for aa, bb in zip(a, b))) diff --git a/multipledispatch/dispatcher.py b/multipledispatch/dispatcher.py index 67b679c..f8dc6ab 100644 --- a/multipledispatch/dispatcher.py +++ b/multipledispatch/dispatcher.py @@ -1,9 +1,14 @@ from warnings import warn import inspect + +import copy + from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning from .utils import expand_tuples -import itertools as itl +import itertools as itl +import pytypes +import typing class MDNotImplementedError(NotImplementedError): @@ -46,6 +51,7 @@ def restart_ordering(on_ambiguity=ambiguity_warn): DeprecationWarning, ) + class Dispatcher(object): """ Dispatch methods based on type signature @@ -140,13 +146,17 @@ def add(self, signature, func): >>> D = Dispatcher('add') >>> D.add((int, int), lambda x, y: x + y) >>> D.add((float, float), lambda x, y: x + y) + >>> D.add((typing.Optional[str], ), lambda x: x) >>> D(1, 2) 3 - >>> D(1, 2.0) + >>> D('1', 2.0) Traceback (most recent call last): ... - NotImplementedError: Could not find signature for add: + NotImplementedError: Could not find signature for add: + >>> D('s') + 's' + >>> D(None) When ``add`` detects a warning it calls the ``on_ambiguity`` callback with a dispatcher/itself, and a set of ambiguous type signature pairs @@ -157,24 +167,35 @@ def add(self, signature, func): annotations = self.get_func_annotations(func) if annotations: signature = annotations + # Make function annotation dict + + def process_union(tp): + if isinstance(tp, tuple): + t = typing.Union[tuple(process_union(e) for e in tp)] + return t + else: + return tp - # Handle union types - if any(isinstance(typ, tuple) for typ in signature): - for typs in expand_tuples(signature): - self.add(typs, func) - return + signatures = expand_tuples(signature) + for signature in signatures: + signature = tuple(process_union(tp) for tp in signature) - for typ in signature: - if not isinstance(typ, type): - str_sig = ', '.join(c.__name__ if isinstance(c, type) - else str(c) for c in signature) - raise TypeError("Tried to dispatch on non-type: %s\n" - "In signature: <%s>\n" - "In function: %s" % - (typ, str_sig, self.name)) + # make a copy of the function (if needed) and apply the function annotations - self.funcs[signature] = func - self._cache.clear() + # TODO: MAKE THIS type or typevar + for typ in signature: + try: + typing.Union[typ] + except TypeError: + str_sig = ', '.join(c.__name__ if isinstance(c, type) + else str(c) for c in signature) + raise TypeError("Tried to dispatch on non-type: %s\n" + "In signature: <%s>\n" + "In function: %s" % + (typ, str_sig, self.name)) + + self.funcs[signature] = func + self._cache.clear() try: del self._ordering @@ -196,7 +217,11 @@ def reorder(self, on_ambiguity=ambiguity_warn): return od def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) + try: + types = tuple([pytypes.deep_type(arg, 1, max_sample=10) for arg in args]) + except: + # some things dont deeptype welkl + types = tuple([type(arg) for arg in args]) try: func = self._cache[types] except KeyError: @@ -259,12 +284,26 @@ def dispatch(self, *types): except StopIteration: return None + @staticmethod + def get_type_vars(x): + if isinstance(x, typing.TypeVar): + yield x + if isinstance(x, typing.GenericMeta): + for e in x.__parameters__: + yield e + def dispatch_iter(self, *types): n = len(types) for signature in self.ordering: - if len(signature) == n and all(map(issubclass, types, signature)): + if len(signature) == n: result = self.funcs[signature] - yield result + try: + typsig = typing.Tuple[signature] + typvars = list(self.get_type_vars(typsig)) + if pytypes.is_subtype(typing.Tuple[types], typsig, bound_typevars={t.__name__: t for t in typvars}): + yield result + except pytypes.InputTypeError: + continue def resolve(self, types): """ Deterimine appropriate implementation for this type signature diff --git a/multipledispatch/tests/test_core.py b/multipledispatch/tests/test_core.py index d3f6eec..763df69 100644 --- a/multipledispatch/tests/test_core.py +++ b/multipledispatch/tests/test_core.py @@ -131,11 +131,11 @@ def f(x): def test_union_types(): @dispatch((A, C)) - def f(x): + def hh(x): return 1 - assert f(A()) == 1 - assert f(C()) == 1 + assert hh(A()) == 1 + assert hh(C()) == 1 def test_namespaces(): diff --git a/multipledispatch/tests/test_dispatcher_3only.py b/multipledispatch/tests/test_dispatcher_3only.py index b041450..c57259d 100644 --- a/multipledispatch/tests/test_dispatcher_3only.py +++ b/multipledispatch/tests/test_dispatcher_3only.py @@ -4,6 +4,8 @@ from multipledispatch import dispatch from multipledispatch.dispatcher import Dispatcher +from multipledispatch.utils import raises +import typing def test_function_annotation_register(): @@ -30,8 +32,23 @@ def inc(x: int): def inc(x: float): return x - 1 + @dispatch() + def inc(x: typing.Optional[str]): + return x + + @dispatch() + def inc(x: typing.List[int]): + return x[0] * 4 + + @dispatch() + def inc(x: typing.List[str]): + return x[0] + 'b' + assert inc(1) == 2 assert inc(1.0) == 0.0 + assert inc('a') == 'a' + assert inc([8]) == 32 + assert inc(['a']) == 'ab' def test_function_annotation_dispatch_custom_namespace(): @@ -68,6 +85,18 @@ def f(self, x: float): assert foo.f(1.0) == 0.0 +def test_diagonal_dispatch(): + T = typing.TypeVar('T') + U = typing.TypeVar('U') + + @dispatch() + def diag(x: T, y: T): + return 'same' + + assert diag(1, 6) == 'same' + assert raises(NotImplementedError, lambda: diag(1, '1')) + + def test_overlaps(): @dispatch(int) def inc(x: int): diff --git a/multipledispatch/utils.py b/multipledispatch/utils.py index 4f49a10..8701756 100644 --- a/multipledispatch/utils.py +++ b/multipledispatch/utils.py @@ -1,3 +1,8 @@ + +import pytypes +import typing + + def raises(err, lamda): try: lamda() @@ -14,15 +19,25 @@ def expand_tuples(L): >>> expand_tuples([1, 2]) [(1, 2)] + + >>> expand_tuples([1, typing.Optional[str]]) #doctest: +ELLIPSIS + [(1, <... 'str'>), (1, <... 'NoneType'>)] """ if not L: return [()] - elif not isinstance(L[0], tuple): - rest = expand_tuples(L[1:]) - return [(L[0],) + t for t in rest] else: - rest = expand_tuples(L[1:]) - return [(item,) + t for t in rest for item in L[0]] + if pytypes.is_Union(L[0]): + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in pytypes.get_Union_params(L[0])] + elif not pytypes.is_of_type(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] # Taken from theano/theano/gof/sched.py