Skip to content

Commit 1fc9072

Browse files
EA: add back _from_scalar / cast_pointwise_result backwards compat (#63367)
1 parent 639ffc8 commit 1fc9072

File tree

6 files changed

+85
-18
lines changed

6 files changed

+85
-18
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def _hash_pandas_object(
208208
values, encoding=encoding, hash_key=hash_key, categorize=categorize
209209
)
210210

211+
def _cast_pointwise_result(self, values: ArrayLike) -> ArrayLike:
212+
values = np.asarray(values, dtype=object)
213+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
214+
211215
# Signature of "argmin" incompatible with supertype "ExtensionArray"
212216
def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
213217
# override base class by adding axis keyword

pandas/core/arrays/arrow/array.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
442442
# e.g. test_by_column_values_with_same_starting_value with nested
443443
# values, one entry of which is an ArrowStringArray
444444
# or test_agg_lambda_complex128_dtype_conversion for complex values
445-
return super()._cast_pointwise_result(values)
445+
values = np.asarray(values, dtype=object)
446+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
446447

447448
if pa.types.is_null(arr.type):
448449
if lib.infer_dtype(values) == "decimal":
@@ -498,7 +499,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
498499
if self.dtype.na_value is np.nan:
499500
# ArrowEA has different semantics, so we return numpy-based
500501
# result instead
501-
return super()._cast_pointwise_result(values)
502+
values = np.asarray(values, dtype=object)
503+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
502504
return ArrowExtensionArray(arr)
503505
return self._from_pyarrow_array(arr)
504506

pandas/core/arrays/base.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
cast,
2020
overload,
2121
)
22+
import warnings
2223

2324
import numpy as np
2425

@@ -33,6 +34,7 @@
3334
cache_readonly,
3435
set_module,
3536
)
37+
from pandas.util._exceptions import find_stack_level
3638
from pandas.util._validators import (
3739
validate_bool_kwarg,
3840
validate_insert_loc,
@@ -86,6 +88,7 @@
8688
AstypeArg,
8789
AxisInt,
8890
Dtype,
91+
DtypeObj,
8992
FillnaOptions,
9093
InterpolateOptions,
9194
NumpySorter,
@@ -383,13 +386,67 @@ def _from_factorized(cls, values, original):
383386
"""
384387
raise AbstractMethodError(cls)
385388

389+
@classmethod
390+
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
391+
"""
392+
Strict analogue to _from_sequence, allowing only sequences of scalars
393+
that should be specifically inferred to the given dtype.
394+
395+
Parameters
396+
----------
397+
scalars : sequence
398+
dtype : ExtensionDtype
399+
400+
Raises
401+
------
402+
TypeError or ValueError
403+
404+
Notes
405+
-----
406+
This is called in a try/except block when casting the result of a
407+
pointwise operation in ExtensionArray._cast_pointwise_result.
408+
"""
409+
try:
410+
return cls._from_sequence(scalars, dtype=dtype, copy=False)
411+
except (ValueError, TypeError):
412+
raise
413+
except Exception:
414+
warnings.warn(
415+
"_from_scalars should only raise ValueError or TypeError. "
416+
"Consider overriding _from_scalars where appropriate.",
417+
stacklevel=find_stack_level(),
418+
)
419+
raise
420+
386421
def _cast_pointwise_result(self, values) -> ArrayLike:
387422
"""
423+
Construct an ExtensionArray after a pointwise operation.
424+
388425
Cast the result of a pointwise operation (e.g. Series.map) to an
389-
array, preserve dtype_backend if possible.
426+
array. This is not required to return an ExtensionArray of the same
427+
type as self or of the same dtype. It can also return another
428+
ExtensionArray of the same "family" if you implement multiple
429+
ExtensionArrays/Dtypes that are interoperable (e.g. if you have float
430+
array with units, this method can return an int array with units).
431+
432+
If converting to your own ExtensionArray is not possible, this method
433+
falls back to returning an array with the default type inference.
434+
If you only need to cast to `self.dtype`, it is recommended to override
435+
`_from_scalars` instead of this method.
436+
437+
Parameters
438+
----------
439+
values : sequence
440+
441+
Returns
442+
-------
443+
ExtensionArray or ndarray
390444
"""
391-
values = np.asarray(values, dtype=object)
392-
return lib.maybe_convert_objects(values, convert_non_numeric=True)
445+
try:
446+
return type(self)._from_scalars(values, dtype=self.dtype)
447+
except (ValueError, TypeError):
448+
values = np.asarray(values, dtype=object)
449+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
393450

394451
# ------------------------------------------------------------------------
395452
# Must be a Sequence

pandas/core/arrays/sparse/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,8 @@ def _from_factorized(cls, values, original) -> Self:
622622
return cls(values, dtype=original.dtype)
623623

624624
def _cast_pointwise_result(self, values):
625-
result = super()._cast_pointwise_result(values)
625+
values = np.asarray(values, dtype=object)
626+
result = lib.maybe_convert_objects(values, convert_non_numeric=True)
626627
if result.dtype.kind == self.dtype.kind:
627628
try:
628629
# e.g. test_groupby_agg_extension

pandas/tests/extension/decimal/array.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@ def _from_sequence_of_strings(cls, strings, *, dtype: ExtensionDtype, copy=False
111111
def _from_factorized(cls, values, original):
112112
return cls(values)
113113

114-
def _cast_pointwise_result(self, values):
115-
result = super()._cast_pointwise_result(values)
116-
try:
117-
# If this were ever made a non-test EA, special-casing could
118-
# be avoided by handling Decimal in maybe_convert_objects
119-
res = type(self)._from_sequence(result, dtype=self.dtype)
120-
except (ValueError, TypeError):
121-
return result
122-
return res
114+
# test to ensure that the base class _cast_pointwise_result works as expected
115+
# def _cast_pointwise_result(self, values):
116+
# try:
117+
# # If this were ever made a non-test EA, special-casing could
118+
# # be avoided by handling Decimal in maybe_convert_objects
119+
# res = type(self)._from_sequence(values, dtype=self.dtype)
120+
# except (ValueError, TypeError):
121+
# return values
122+
# return res
123123

124124
_HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)
125125

pandas/tests/extension/json/array.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,14 @@ def _from_factorized(cls, values, original):
9494
return cls([UserDict(x) for x in values if x != ()])
9595

9696
def _cast_pointwise_result(self, values):
97-
result = super()._cast_pointwise_result(values)
9897
try:
99-
return type(self)._from_sequence(result, dtype=self.dtype)
98+
return type(self)._from_sequence(values, dtype=self.dtype)
10099
except (ValueError, TypeError):
101-
return result
100+
# TODO replace with public function
101+
from pandas._libs import lib
102+
103+
values = np.asarray(values, dtype=object)
104+
return lib.maybe_convert_objects(values, convert_non_numeric=True)
102105

103106
def __getitem__(self, item):
104107
if isinstance(item, tuple):

0 commit comments

Comments
 (0)