SKLModelHandler drops extra keyword arguments in predict method #209

@robert-norberg

Description

When an sklearn-compatible estimator exposes a method with extra keyword arguments (e.g. predict(X, my_param=100)), the Snowflake Model Registry's SKLModelHandler wraps it in a closure that only accepts (self, X), so there is no way to pass kwargs through. As a result, ParamSpec parameters registered in the model signature are never forwarded to the underlying estimator, and passing them at inference time raises a TypeError.

Minimal reproducible example:

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from snowflake.ml.model import model_signature
from snowflake.ml.model.model_signature import DataType, ParamSpec
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs({<redacted>}).create()
registry = Registry(session=session)

class EstimatorWithKwarg(BaseEstimator):
    """Trivial estimator whose predict() accepts an extra `my_param` kwarg."""

    def fit(self, X, y=None):
        return self

    def predict(self, X, *, my_param=365):
        # Return `my_param` to verify it was forwarded
        return np.full(len(X), fill_value=my_param, dtype=float)

estimator = EstimatorWithKwarg().fit(pd.DataFrame({"a": [1, 2, 3]}))

# Verify the estimator itself works with my_param kwarg
X = pd.DataFrame({"a": [10, 20]})
y_hat = estimator.predict(X, my_param=180)
y_hat
# Output: array([180., 180.])

predict_sig = model_signature.infer_signature(
    X,
    y_hat,
    output_feature_names=["pred"],
    params=[ParamSpec(name="my_param", dtype=DataType.INT32, default_value=100)],
)

registry.log_model(
    model=estimator,
    model_name="deleteme",
    sample_input_data=X,
    signatures={"predict": predict_sig},
)
model_ref = registry.get_model("deleteme")
mv = model_ref.version("default")
mv.run(X, function_name="predict", params={"my_param": 100})
SnowparkSQLException: (1304): 01c2d5a2-081b-3567-0026-55033fabff02: 100357 (P0000): Python Interpreter Error:
Traceback (most recent call last):
  File "/home/udf/10789520488747366/predict.py", line 91, in infer
    predictions_df = runner(input_df, **method_params)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: SKLModelHandler.convert_as_custom_model.<locals>._create_custom_model.<locals>.fn_factory.<locals>.fn() got an unexpected keyword argument 'my_param' in function PREDICT with handler predict.infer
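
The traceback points at a nested fn created by fn_factory. The following is a minimal sketch of that pattern and is not the library's actual source. It uses no Snowflake or sklearn imports, and a plain list stands in for the DataFrame. The names fn_factory_fixed and fn_factory_forwarding are hypothetical, as is the stand-in estimator. It only illustrates why a wrapper with a fixed signature rejects my_param, and how forwarding **kwargs would presumably let it reach the estimator.

```python
class EstimatorWithKwarg:
    """Stand-in for the estimator in the example above (pure Python)."""

    def predict(self, X, *, my_param=365):
        # Return `my_param` once per row so forwarding can be verified.
        return [float(my_param)] * len(X)


def fn_factory_fixed(estimator, method_name):
    """Hypothetical wrapper whose inner fn only accepts X."""
    method = getattr(estimator, method_name)

    def fn(X):
        return method(X)

    return fn


def fn_factory_forwarding(estimator, method_name):
    """Hypothetical wrapper whose inner fn forwards extra keyword arguments."""
    method = getattr(estimator, method_name)

    def fn(X, **kwargs):
        return method(X, **kwargs)

    return fn


X = [10, 20]  # assumption: a list stands in for the input DataFrame
estimator = EstimatorWithKwarg()

# Fixed signature: calling with my_param raises TypeError, as in the traceback.
fixed = fn_factory_fixed(estimator, "predict")
try:
    fixed(X, my_param=100)
    fixed_error = None
except TypeError as exc:
    fixed_error = str(exc)

# Forwarding signature: my_param reaches the estimator's predict().
forwarding = fn_factory_forwarding(estimator, "predict")
forwarded = forwarding(X, my_param=100)
default_result = forwarding(X)  # falls back to the estimator default
```

In this sketch the forwarding wrapper also keeps working when no params are supplied, so existing callers of predict(X) would be unaffected.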
