|
29 | 29 | is_categorical_dtype, |
30 | 30 | pandas_dtype, |
31 | 31 | ) |
32 | | -from pandas.core.dtypes.concat import union_categoricals |
33 | | -from pandas.core.dtypes.dtypes import ExtensionDtype |
| 32 | +from pandas.core.dtypes.concat import ( |
| 33 | + concat_compat, |
| 34 | + union_categoricals, |
| 35 | +) |
34 | 36 |
|
35 | 37 | from pandas.core.indexes.api import ensure_index_from_sequences |
36 | 38 |
|
@@ -378,43 +380,15 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict: |
378 | 380 | arrs = [chunk.pop(name) for chunk in chunks] |
379 | 381 | # Check each arr for consistent types. |
380 | 382 | dtypes = {a.dtype for a in arrs} |
381 | | - # TODO: shouldn't we exclude all EA dtypes here? |
382 | | - numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
383 | | - if len(numpy_dtypes) > 1: |
384 | | - # error: Argument 1 to "find_common_type" has incompatible type |
385 | | - # "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type, |
386 | | - # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, |
387 | | - # Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]" |
388 | | - common_type = np.find_common_type( |
389 | | - numpy_dtypes, # type: ignore[arg-type] |
390 | | - [], |
391 | | - ) |
392 | | - if common_type == np.dtype(object): |
393 | | - warning_columns.append(str(name)) |
| 383 | + non_cat_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
394 | 384 |
|
395 | 385 | dtype = dtypes.pop() |
396 | 386 | if is_categorical_dtype(dtype): |
397 | 387 | result[name] = union_categoricals(arrs, sort_categories=False) |
398 | 388 | else: |
399 | | - if isinstance(dtype, ExtensionDtype): |
400 | | - # TODO: concat_compat? |
401 | | - array_type = dtype.construct_array_type() |
402 | | - # error: Argument 1 to "_concat_same_type" of "ExtensionArray" |
403 | | - # has incompatible type "List[Union[ExtensionArray, ndarray]]"; |
404 | | - # expected "Sequence[ExtensionArray]" |
405 | | - result[name] = array_type._concat_same_type( |
406 | | - arrs # type: ignore[arg-type] |
407 | | - ) |
408 | | - else: |
409 | | - # error: Argument 1 to "concatenate" has incompatible |
410 | | - # type "List[Union[ExtensionArray, ndarray[Any, Any]]]" |
411 | | - # ; expected "Union[_SupportsArray[dtype[Any]], |
412 | | - # Sequence[_SupportsArray[dtype[Any]]], |
413 | | - # Sequence[Sequence[_SupportsArray[dtype[Any]]]], |
414 | | - # Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]] |
415 | | - # , Sequence[Sequence[Sequence[Sequence[ |
416 | | - # _SupportsArray[dtype[Any]]]]]]]" |
417 | | - result[name] = np.concatenate(arrs) # type: ignore[arg-type] |
| 389 | + result[name] = concat_compat(arrs) |
| 390 | + if len(non_cat_dtypes) > 1 and result[name].dtype == np.dtype(object): |
| 391 | + warning_columns.append(str(name)) |
418 | 392 |
|
419 | 393 | if warning_columns: |
420 | 394 | warning_names = ",".join(warning_columns) |
|
0 commit comments