-
Notifications
You must be signed in to change notification settings - Fork 139
Open
Labels
bug: Something isn't working
Description
Describe the bug
I'm trying to write a UDAF (user-defined aggregate function) that returns multiple timestamps for each partition id.
To Reproduce
import datafusion as dfn
from datafusion import udf, udaf, Accumulator, col
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np
class ResampleAccumulator(Accumulator):
    """Aggregate that resamples a group's time span onto a fixed 10 Hz grid.

    The running state is the smallest and largest input timestamp, kept as
    integer nanoseconds.  ``evaluate`` emits every grid point from the
    minimum (inclusive) up to the maximum (exclusive).
    """

    # Sentinel meaning "no timestamp seen yet"; compares greater than any int.
    _NO_MIN = float('inf')

    def __init__(self):
        self._min = self._NO_MIN
        self._max = 0
        # 10 Hz
        self._timestep = 100  # ms

    def update(self, array):
        """Fold one batch of input timestamps into the running min/max.

        Args:
            array: pyarrow array of timestamps for the current group.
        """
        print("Enter update")
        local_min, local_max = pc.min_max(array).values()
        # min_max yields null scalars for an empty or all-null batch; skip
        # those so min()/max() are never handed a None.
        if local_min.is_valid and local_max.is_valid:
            local_min_ns = local_min.cast(pa.timestamp('ns')).value
            local_max_ns = local_max.cast(pa.timestamp('ns')).value
            self._min = min(local_min_ns, self._min)
            self._max = max(local_max_ns, self._max)
        print(f"update {self._min=}, {self._max=}")

    def merge(self, states_array):
        """Combine partial states produced by ``state`` into this accumulator.

        Args:
            states_array: one pyarrow array per state field (minimums first,
                maximums second).  Each array may carry several rows, one per
                partial state, so the whole column is reduced rather than
                only its first element.
        """
        print("Enter merge")
        merged_min = pc.min(states_array[0]).as_py()
        merged_max = pc.max(states_array[1]).as_py()
        # A null result means every incoming partial state was empty.
        if merged_min is not None:
            self._min = min(merged_min, self._min)
        if merged_max is not None:
            self._max = max(merged_max, self._max)
        print(f"merge {self._min=}, {self._max=}")

    def state(self):
        """Return the current (min, max) pair as a two-element int64 array."""
        print("Enter state")
        # The float sentinel cannot be stored as int64, so an accumulator that
        # has seen no data reports a null minimum instead.
        min_state = None if self._min == self._NO_MIN else self._min
        return pa.array([min_state, self._max], type=pa.int64())

    def evaluate(self):
        """Build the resampled timestamps covering [min, max) for this group.

        Returns:
            A pyarrow timestamp('ns') array with one entry per grid step.
        """
        print("Enter evaluate")
        if self._min == self._NO_MIN:
            # No non-null input was seen, so there is no span to resample.
            return pa.array([], type=pa.timestamp('ns'))
        desired_timestamps = np.arange(
            np.datetime64(self._min, 'ns'),
            np.datetime64(self._max, 'ns'),
            np.timedelta64(self._timestep, "ms"),
        )
        print(f"{len(desired_timestamps)=}")
        array_result = pa.array(desired_timestamps, type=pa.timestamp('ns'))
        print(array_result)
        # NOTE(review): the error reported with this reproduction names
        # pyarrow's TimestampArray, i.e. the kind of object returned here.
        # The Accumulator interface appears to expect a pyarrow Scalar from
        # evaluate(); presumably a list-typed scalar is needed for a
        # pa.list_ return type -- TODO confirm against the datafusion-python
        # docs.  Left unchanged so the script still reproduces the report.
        return array_result
# Register the aggregate: one ns-timestamp input column, a list of
# ns-timestamps as the declared result, and two int64 state fields.
input_types = [pa.timestamp('ns')]
return_type = pa.list_(pa.timestamp('ns'))
state_types = [pa.int64(), pa.int64()]
resample_udaf = udaf(
    ResampleAccumulator,
    input_types,
    return_type,
    state_types,
    volatility="stable",
)

# Two rows in two different groups, one second apart.
ctx = dfn.SessionContext()
sample_rows = {
    "id": [0, 1],
    "time": [np.datetime64(0, 'ns'), np.datetime64(1_000_000_000, 'ns')],
}
df = ctx.from_pydict(sample_rows)
print(df)

# Group by "id" and run the aggregate over each group's "time" values.
resampled_time = resample_udaf(col("time"))
result = df.aggregate("id", [resampled_time])
print(result.schema())
result.collect()
Output
Traceback (most recent call last):
File "<path>/<file>.py", line 60, in <module>
result.collect()
File "<path>/.venv/lib/python3.12/site-packages/datafusion/dataframe.py", line 729, in collect
return self.df.collect()
^^^^^^^^^^^^^^^^^
Exception: DataFusion error: Execution("ArrowTypeError: object of type <class 'pyarrow.lib.TimestampArray'> cannot be converted to int")
Expected behavior
The query should either run successfully or fail with a clearer error message.
Additional context
This fails on datafusion 51. If I return just a single timestamp instead of an array and update the udaf call to match, it works.
Metadata
Metadata
Assignees
Labels
bug: Something isn't working