1414 Sequence ,
1515 cast ,
1616 final ,
17+ overload ,
1718)
1819import warnings
1920
101102 Categorical ,
102103 DataFrame ,
103104 Index ,
105+ MultiIndex ,
104106 Series ,
105107 )
106108 from pandas .core .arrays import (
@@ -1792,7 +1794,7 @@ def safe_sort(
17921794 na_sentinel : int = - 1 ,
17931795 assume_unique : bool = False ,
17941796 verify : bool = True ,
1795- ) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
1797+ ) -> np .ndarray | MultiIndex | tuple [np .ndarray | MultiIndex , np .ndarray ]:
17961798 """
17971799 Sort ``values`` and reorder corresponding ``codes``.
17981800
@@ -1821,7 +1823,7 @@ def safe_sort(
18211823
18221824 Returns
18231825 -------
1824- ordered : ndarray
1826+ ordered : ndarray or MultiIndex
18251827 Sorted ``values``
18261828 new_codes : ndarray
18271829 Reordered ``codes``; returned when ``codes`` is not None.
@@ -1839,6 +1841,8 @@ def safe_sort(
18391841 raise TypeError (
18401842 "Only list-like objects are allowed to be passed to safe_sort as values"
18411843 )
1844+ original_values = values
1845+ is_mi = isinstance (original_values , ABCMultiIndex )
18421846
18431847 if not isinstance (values , (np .ndarray , ABCExtensionArray )):
18441848 # don't convert to string types
@@ -1850,6 +1854,7 @@ def safe_sort(
18501854 values = np .asarray (values , dtype = dtype ) # type: ignore[arg-type]
18511855
18521856 sorter = None
1857+ ordered : np .ndarray | MultiIndex
18531858
18541859 if (
18551860 not is_extension_array_dtype (values )
@@ -1859,13 +1864,17 @@ def safe_sort(
18591864 else :
18601865 try :
18611866 sorter = values .argsort ()
1862- ordered = values .take (sorter )
1867+ if is_mi :
1868+ # Operate on the original object instead of the cast array (MultiIndex)
1869+ ordered = original_values .take (sorter )
1870+ else :
1871+ ordered = values .take (sorter )
18631872 except TypeError :
18641873 # Previous sorters failed or were not applicable, try `_sort_mixed`
18651874 # which would work, but which fails for special case of 1d arrays
18661875 # with tuples.
18671876 if values .size and isinstance (values [0 ], tuple ):
1868- ordered = _sort_tuples (values )
1877+ ordered = _sort_tuples (values , original_values )
18691878 else :
18701879 ordered = _sort_mixed (values )
18711880
@@ -1927,19 +1936,33 @@ def _sort_mixed(values) -> np.ndarray:
19271936 )
19281937
19291938
1930- def _sort_tuples (values : np .ndarray ) -> np .ndarray :
1939+ @overload
1940+ def _sort_tuples (values : np .ndarray , original_values : np .ndarray ) -> np .ndarray :
1941+ ...
1942+
1943+
1944+ @overload
1945+ def _sort_tuples (values : np .ndarray , original_values : MultiIndex ) -> MultiIndex :
1946+ ...
1947+
1948+
1949+ def _sort_tuples (
1950+ values : np .ndarray , original_values : np .ndarray | MultiIndex
1951+ ) -> np .ndarray | MultiIndex :
19311952 """
19321953 Convert array of tuples (1d) to array of arrays (2d).
19331954 We need to keep the columns separately as they contain different types and
19341955 nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19351956 column as types cannot be compared).
1957+ We have to apply the indexer to the original values to keep the dtypes in
1958+ case of MultiIndexes.
19361959 """
19371960 from pandas .core .internals .construction import to_arrays
19381961 from pandas .core .sorting import lexsort_indexer
19391962
19401963 arrays , _ = to_arrays (values , None )
19411964 indexer = lexsort_indexer (arrays , orders = True )
1942- return values [indexer ]
1965+ return original_values [indexer ]
19431966
19441967
19451968def union_with_duplicates (lvals : ArrayLike , rvals : ArrayLike ) -> ArrayLike :
0 commit comments