def has_dtypes(df, items):
    """
    Assert that a DataFrame has ``dtypes`` as described in ``items``.

    Parameters
    ==========
    df: DataFrame
    items: dict
        A mapping of column names to:
            - dtypes, and/or
            - functions (but **not** other callables!) that take a
              pandas.Series.dtype instance as input, and return ``True``
              if the dtype is acceptable and ``False`` otherwise.

        Only plain Python functions (``types.FunctionType``, which includes
        lambdas) are called as predicates. Every other value is compared
        against the column's dtype with ``pandas.api.types.is_dtype_equal``.

    Returns
    =======
    df : DataFrame
        The same object that was passed in.

    Raises
    ======
    AssertionError
        If a predicate returns something that is not a boolean, if a
        predicate returns a false value, or if a column's dtype is not equal
        to the dtype given for it.

    Examples
    ========

    .. code:: python

        import numpy as np
        import pandas as pd
        import engarde.checks as ck

        df = pd.DataFrame({'A': np.random.randint(0, 10, 10),
                           'B': np.random.randn(10)})
        # A predicate is used for 'A' because the integer width produced by
        # ``randint`` can differ between platforms and NumPy versions.
        df = df.pipe(ck.has_dtypes, items={'A': pd.api.types.is_integer_dtype,
                                           'B': np.float64})

    """
    from types import FunctionType
    from pandas.api.types import is_bool, is_dtype_equal

    dtypes = df.dtypes
    for k, v in items.items():
        actual = dtypes[k]
        if isinstance(v, FunctionType):
            result = v(actual)
            # ``is_bool`` accepts both the builtin ``bool`` and ``numpy.bool_``;
            # NumPy comparisons return the latter, which is not a ``bool``
            # subclass and would be rejected by a plain ``isinstance`` check.
            if not is_bool(result):
                raise AssertionError("The function for key {!r}"
                                     " must return a boolean, returned {!r}".format(k, type(result)))
            if not result:
                raise AssertionError("{} has the wrong dtype ({}) for function ({})".format(k, actual, v.__name__))
        elif not is_dtype_equal(actual, v):
            # Report the dtype the column actually has, plus the expected one,
            # so the message matches what the predicate branch reports.
            raise AssertionError("{} has the wrong dtype ({}), expected ({})".format(k, actual, v))
    return df
def test_has_dtypes_funcs():
    """Exercise ``has_dtypes`` with dtype-predicate functions as the expected values.

    A four-column frame (random integers, random floats, single-character
    strings, a categorical) is checked once through ``ck.has_dtypes`` and once
    through ``dc.has_dtypes`` wrapped around ``_noop``; both must hand back a
    frame equal to the input. A predicate that does not match column ``A``
    must then raise ``AssertionError`` through either entry point.
    """
    type_checks = pd.api.types

    # Columns are built in the same order as the random draws: ints, floats,
    # then the categorical choice.
    columns = {}
    columns['A'] = np.random.randint(0, 10, 10)
    columns['B'] = np.random.randn(10)
    columns['C'] = list('abcdefghij')
    columns['D'] = pd.Categorical(np.random.choice(['a', 'b'], 10))
    frame = pd.DataFrame(columns)

    predicates = dict(
        A=type_checks.is_integer_dtype,
        B=type_checks.is_float_dtype,
        C=type_checks.is_string_dtype,
        D=type_checks.is_category_dtype,
    )

    # Passing case, function form.
    checked = ck.has_dtypes(frame, predicates)
    tm.assert_frame_equal(frame, checked)

    # Passing case, decorator form.
    decorated = dc.has_dtypes(items=predicates)(_noop)(frame)
    tm.assert_frame_equal(frame, decorated)

    # Failing case, function form: 'A' holds integers, not floats.
    wrong_for_check = {'A': type_checks.is_float_dtype}
    with pytest.raises(AssertionError):
        ck.has_dtypes(frame, wrong_for_check)

    # Failing case, decorator form: 'A' holds integers, not booleans.
    with pytest.raises(AssertionError):
        dc.has_dtypes(items={'A': type_checks.is_bool_dtype})(_noop)(frame)