Skip to content

Cannot do udaf that returns list of timestamps #1339

@ntjohnson1

Description

@ntjohnson1

Describe the bug
I'm trying to create a UDAF that returns a list of timestamps for each partition id.

To Reproduce

import datafusion as dfn
from datafusion import udf, udaf, Accumulator, col
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np

class ResampleAccumulator(Accumulator):
    """Track the smallest and largest timestamp seen and emit evenly spaced timestamps between them.

    The running bounds are kept as plain numbers holding nanosecond values
    (``_min`` / ``_max``). ``evaluate`` builds the timestamps from ``_min`` up to,
    but not including, ``_max`` in steps of ``_timestep`` milliseconds.
    """

    def __init__(self):
        # Start with an "empty" range: any real minimum is smaller than +inf.
        self._min = float('inf')
        self._max = 0
        # 10 Hz
        self._timestep = 100 # ms

    def update(self, array):
        """Fold the minimum and maximum of one input batch into the running bounds.

        A batch without any valid (non-null) value leaves the bounds unchanged.
        """
        print("Enter update")
        local_min, local_max = pc.min_max(array).values()
        # min_max yields null scalars for an empty or all-null batch; their
        # ``.value`` is None, which cannot be compared against the running bounds.
        if local_min.is_valid and local_max.is_valid:
            local_min_ns = local_min.cast(pa.timestamp('ns')).value
            local_max_ns = local_max.cast(pa.timestamp('ns')).value

            self._min = min(local_min_ns, self._min)
            self._max = max(local_max_ns, self._max)
        print(f"update {self._min=}, {self._max=}")

    def merge(self, states_array):
        """Fold serialized partial states into the running bounds.

        ``states_array[0]`` holds minimum values and ``states_array[1]`` holds
        maximum values, in the order produced by ``state``. Every row of each
        array is taken into account, not only the first one.
        """
        print("Enter merge")
        # Reduce over the whole array: a merge batch may carry more than one
        # partial state, and looking only at row 0 would drop the others.
        merged_min = pc.min(states_array[0]).as_py()
        merged_max = pc.max(states_array[1]).as_py()
        # as_py() is None when the array is empty or all null; skip in that case.
        if merged_min is not None:
            self._min = min(merged_min, self._min)
        if merged_max is not None:
            self._max = max(merged_max, self._max)
        print(f"merge {self._min=}, {self._max=}")


    def state(self):
        """Return the current bounds as a two-element int64 array ``[min, max]``."""
        print("Enter state")
        # NOTE(review): if no value has been folded in yet, ``_min`` is still
        # float('inf'), which presumably cannot be stored as int64 — confirm.
        return pa.array([self._min, self._max], type=pa.int64())

    def evaluate(self):
        """Build the timestamps from ``_min`` to ``_max`` (exclusive) at ``_timestep`` ms spacing.

        Returns the timestamps as a pyarrow array of ``timestamp('ns')``.
        """
        print("Enter evaluate")
        desired_timestamps = np.arange(np.datetime64(self._min, 'ns'), np.datetime64(self._max, 'ns'), np.timedelta64(self._timestep, "ms"))
        print(f"{len(desired_timestamps)=}")
        array_result = pa.array(desired_timestamps, type=pa.timestamp('ns'))
        print(array_result)
        # NOTE(review): this hands back a pyarrow array rather than a single
        # list-typed scalar — confirm which form the UDAF machinery expects
        # when the declared return type is a list.
        return array_result
# Register the accumulator as a UDAF: a timestamp(ns) input, a list of
# timestamp(ns) as the result, and two int64 values matching what the
# accumulator's state() returns.
timestamp_ns = pa.timestamp('ns')
resample_udaf = udaf(
    ResampleAccumulator,
    [timestamp_ns],
    pa.list_(timestamp_ns),
    [pa.int64(), pa.int64()],
    volatility="stable",
)

ctx = dfn.SessionContext()

# Two rows with distinct ids whose timestamps are one second apart.
sample_columns = {
    "id": [0, 1],
    "time": [np.datetime64(0, 'ns'), np.datetime64(1_000_000_000, 'ns')],
}
df = ctx.from_pydict(sample_columns)
print(df)

# Group by id and run the UDAF over the time column of each group.
aggregates = [resample_udaf(col("time"))]
result = df.aggregate("id", aggregates)
print(result.schema())
result.collect()

Output

Traceback (most recent call last):
  File "<path>/<file>.py", line 60, in <module>
    result.collect()
  File "<path>/.venv/lib/python3.12/site-packages/datafusion/dataframe.py", line 729, in collect
    return self.df.collect()
           ^^^^^^^^^^^^^^^^^
Exception: DataFusion error: Execution("ArrowTypeError: object of type <class 'pyarrow.lib.TimestampArray'> cannot be converted to int")

Expected behavior
Either this works, or it fails with a clearer error message.

Additional context
This fails on datafusion 51. If I return just a single timestamp and update the `udaf` call accordingly, it works.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions