Typing #1792 (Open)
doc/extending/ctype.rst (6 changes: 3 additions & 3 deletions)

@@ -463,10 +463,10 @@ Final version
 
     class Double(Type):
 
-        def filter(self, x, strict=False, allow_downcast=None):
-            if strict and not isinstance(x, float):
+        def filter(self, data, strict=False, allow_downcast=None):
+            if strict and not isinstance(data, float):
                 raise TypeError('Expected a float!')
-            return float(x)
+            return float(data)
 
         def values_eq_approx(self, x, y, tolerance=1e-4):
             return abs(x - y) / (x + y) < tolerance
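For readers skimming the rename: the behaviour of filter() is unchanged, only its first parameter is now called data. A standalone sketch of the contract the documented Double type illustrates (plain Python, not the library class):

    # Standalone sketch mirroring the documented filter() contract above:
    # coerce to float when strict=False, reject non-floats when strict=True.
    class Double:
        def filter(self, data, strict=False, allow_downcast=None):
            if strict and not isinstance(data, float):
                raise TypeError("Expected a float!")
            return float(data)

    d = Double()
    print(d.filter(3))                  # 3.0 (coerced)
    print(d.filter(3.5, strict=True))   # 3.5 (already a float)
    try:
        d.filter(3, strict=True)        # rejected in strict mode
    except TypeError as exc:
        print(exc)                      # Expected a float!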
doc/extending/type.rst (2 changes: 1 addition & 1 deletion)

@@ -417,7 +417,7 @@ required methods of the interface, except ``filter``.
 
     class DoubleType(Type):
 
-        def filter(self, x, strict=False, allow_downcast=None):
+        def filter(self, data, strict=False, allow_downcast=None):
             # See code above.
             ...
 
pytensor/breakpoint.py (24 changes: 13 additions & 11 deletions)

@@ -4,6 +4,7 @@
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.tensor.basic import as_tensor_variable
+from pytensor.tensor.type import TensorType
 
 
 class PdbBreakpoint(Op):
@@ -50,23 +51,26 @@ class PdbBreakpoint(Op):
         # as the individual error values
         breakpointOp = PdbBreakpoint("MSE too high")
         condition = pt.gt(mse.sum(), 100)
-        mse, monitored_input, monitored_target = breakpointOp(condition, mse,
-                                                               input, target)
+        mse, monitored_input, monitored_target = breakpointOp(
+            condition, mse, input, target
+        )
 
         # Compile the pytensor function
         fct = pytensor.function([input, target], mse)
 
         # Use the function
-        print fct([10, 0], [10, 5]) # Will NOT activate the breakpoint
-        print fct([0, 0], [10, 5]) # Will activate the breakpoint
+        print(fct([10, 0], [10, 5])) # Will NOT activate the breakpoint
+        print(fct([0, 0], [10, 5])) # Will activate the breakpoint
 
 
     """
 
     __props__ = ("name",)
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self.name = name
+        self.view_map = {}
+        self.inp_types: list[TensorType] = []
 
     def make_node(self, condition, *monitored_vars):
         # Ensure that condition is an PyTensor tensor
@@ -83,13 +87,11 @@ def make_node(self, condition, *monitored_vars):
         # (view_map and var_types) in that instance and then apply it on the
         # inputs.
         new_op = PdbBreakpoint(name=self.name)
-        new_op.view_map = {}
-        new_op.inp_types = []
-        for i in range(len(monitored_vars)):
+        for i, var in enumerate(monitored_vars):
             # Every output i is a view of the input i+1 because of the input
             # condition.
             new_op.view_map[i] = [i + 1]
-            new_op.inp_types.append(monitored_vars[i].type)
+            new_op.inp_types.append(var.type)
 
         # Build the Apply node
         inputs = [condition, *monitored_vars]
@@ -141,8 +143,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(len(output_storage)):
             output_storage[i][0] = inputs[i + 1]
 
-    def grad(self, inputs, output_gradients):
-        return [DisconnectedType()(), *output_gradients]
+    def grad(self, inputs, output_grads):
+        return [DisconnectedType()(), *output_grads]
 
     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)
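One note on the refactor above: moving view_map and inp_types into __init__ does not change the bookkeeping itself; output i is still declared a view of input i + 1 because input 0 is the condition. A tiny standalone sketch of that mapping (hypothetical placeholder names, not pytensor code):

    # Sketch of the view_map bookkeeping built in make_node above: with the
    # condition occupying input slot 0, monitored variable i becomes output i
    # and views input i + 1.
    monitored_vars = ["mse", "input", "target"]   # hypothetical placeholders

    view_map = {i: [i + 1] for i, _ in enumerate(monitored_vars)}
    print(view_map)   # {0: [1], 1: [2], 2: [3]}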
pytensor/compile/builders.py (22 changes: 11 additions & 11 deletions)

@@ -1,11 +1,11 @@
 """Define new Ops from existing Ops"""
 
 import warnings
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
 from copy import copy
 from functools import partial
 from itertools import chain
-from typing import Union, cast
+from typing import Union
 
 from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
@@ -88,12 +88,12 @@ def local_traverse(out):
 
 
 def construct_nominal_fgraph(
-    inputs: Sequence[Variable], outputs: Sequence[Variable]
+    inputs: list[Variable], outputs: list[Variable]
 ) -> tuple[
     FunctionGraph,
-    Sequence[Variable],
-    dict[Variable, Variable],
-    dict[Variable, Variable],
+    list[SharedVariable],
+    dict[SharedVariable, Variable],
+    list[Variable],
 ]:
     """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
     implicit_shared_inputs = []
@@ -119,7 +119,7 @@ def construct_nominal_fgraph(
     )
 
     new = rebuild_collect_shared(
-        cast(Sequence[Variable], outputs),
+        outputs,
         inputs=inputs + implicit_shared_inputs,
         replace=replacements,
         copy_inputs_over=False,
@@ -401,7 +401,7 @@ def __init__(
         self.output_types = [out.type for out in outputs]
 
         for override in (lop_overrides, grad_overrides, rop_overrides):
-            if override == "default":
+            if override == "default":  # type: ignore[comparison-overlap]
                 raise ValueError(
                     "'default' is no longer a valid value for overrides. Use None instead."
                 )
@@ -702,7 +702,7 @@ def _build_and_cache_rop_op(self):
         # Return a wrapper that combines connected and disconnected output gradients
        def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
             connected_output_grads = iter(rop_op(*inputs, **kwargs))
-            all_output_grads = []
+            all_output_grads: list[Variable | None] = []
             for out_grad in output_grads:
                 if isinstance(out_grad.type, DisconnectedType):
                     # R_Op does not have DisconnectedType yet, None should be used instead
@@ -875,8 +875,8 @@ def clone(self):
         res.fgraph = res.fgraph.clone()
         return res
 
-    def perform(self, node, inputs, outputs):
+    def perform(self, node, inputs, output_storage):
         variables = self.fn(*inputs)
         # zip strict not specified because we are in a hot loop
-        for output, variable in zip(outputs, variables):
+        for output, variable in zip(output_storage, variables):
             output[0] = variable
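The perform rename in the last hunk sits next to a comment about zip and hot loops; for readers unfamiliar with it, a small sketch of the zip(strict=...) trade-off that comment refers to (illustrative values only):

    # zip() silently truncates to the shortest input by default; strict=True
    # (Python 3.10+) adds a length check, at the cost of a little overhead.
    a, b = [1, 2, 3], ["x", "y"]

    print(list(zip(a, b)))              # [(1, 'x'), (2, 'y')] -- silently truncated
    try:
        list(zip(a, b, strict=True))    # raises when the lengths differ
    except ValueError as exc:
        print(exc)                      # zip() argument 2 is shorter than argument 1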
pytensor/compile/debugmode.py (8 changes: 4 additions & 4 deletions)

@@ -29,7 +29,8 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Variable
 from pytensor.graph.destroyhandler import DestroyHandler
-from pytensor.graph.features import AlreadyThere, BadOptimization
+from pytensor.graph.features import AlreadyThere
+from pytensor.graph.features import BadOptimization as _BadOptimization
 from pytensor.graph.fg import Output
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.traversal import io_toposort
@@ -144,7 +145,7 @@ def str_diagnostic(self):
         return ret
 
 
-class BadOptimization(DebugModeError, BadOptimization):
+class BadOptimization(DebugModeError, _BadOptimization):
     pass
 
 
@@ -2244,8 +2245,7 @@ class DebugMode(Mode):
 
     """
 
-    check_preallocated_output = config.DebugMode__check_preallocated_output
-    check_preallocated_output = check_preallocated_output.split(":")
+    check_preallocated_output = config.DebugMode__check_preallocated_output.split(":")
     """
     List of strings representing ways to pre-allocate output memory in
     tests. Valid values are: "previous" (previously-returned memory),
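The split import above exists because this module defines its own BadOptimization that inherits from the imported one, so the base class needs a distinct local name. A minimal sketch of the pattern with stand-in classes (not the pytensor ones):

    # Sketch of the import-alias pattern: bind the imported base under a private
    # name so a local class can reuse the public name and still inherit from it.
    from collections import OrderedDict as _OrderedDict   # stand-in for the base class

    class OrderedDict(_OrderedDict):
        """Locally extended version of the imported class."""

    d = OrderedDict(a=1)
    print(isinstance(d, _OrderedDict))   # True: the base stays reachable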
pytensor/compile/function/pfunc.py (18 changes: 9 additions & 9 deletions)

@@ -163,7 +163,7 @@ def rebuild_collect_shared(
 
     # This function implements similar functionality as graph.clone
     # and it should be merged with that
-    clone_d = {}
+    clone_d: dict = {}
     update_d = {}
     update_expr = []
     # list of shared inputs that are used as inputs of the graph
@@ -300,32 +300,32 @@ def clone_inputs(i):
             update_expr.append((store_into, update_val))
 
     # Elements of "outputs" are here cloned to "cloned_outputs"
+    cloned_outputs: list[Variable] | Variable | Out | list[Out]
     if isinstance(outputs, list):
-        cloned_outputs = []
+        cloned_outputs_list = []

    Review thread on the line above:
    Contributor: Can cloned_outputs just be used?
    Member (author): It wouldn't be type-stable and mypy doesn't like it. The problem is that if you declare it as list[T] | T it won't let you redeclare it as list[T].
    (A short sketch of this mypy behaviour follows this file's diff.)

         for v in outputs:
             if isinstance(v, Variable):
                 cloned_v = clone_v_get_shared_updates(v, copy_inputs_over)
-                cloned_outputs.append(cloned_v)
+                cloned_outputs_list.append(cloned_v)
             elif isinstance(v, Out):
-                cloned_v = clone_v_get_shared_updates(v.variable, copy_inputs_over)
-                cloned_outputs.append(Out(cloned_v, borrow=v.borrow))
+                cloned_o = clone_v_get_shared_updates(v.variable, copy_inputs_over)
+                cloned_outputs_list.append(Out(cloned_o, borrow=v.borrow))
             else:
                 raise TypeError(
                     "Outputs must be pytensor Variable or "
                     "Out instances. Received " + str(v) + " of type " + str(type(v))
                 )
-            # computed_list.append(cloned_v)
+        cloned_outputs = cloned_outputs_list
     else:
         if isinstance(outputs, Variable):
             cloned_v = clone_v_get_shared_updates(outputs, copy_inputs_over)
             cloned_outputs = cloned_v
             # computed_list.append(cloned_v)
         elif isinstance(outputs, Out):
-            cloned_v = clone_v_get_shared_updates(outputs.variable, copy_inputs_over)
-            cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
+            cloned_o = clone_v_get_shared_updates(outputs.variable, copy_inputs_over)
+            cloned_outputs = Out(cloned_o, borrow=outputs.borrow)
             # computed_list.append(cloned_v)
         elif outputs is None:
             cloned_outputs = []  # TODO: get Function.__call__ to return None
         else:
             raise TypeError(
                 "output must be an PyTensor Variable or Out instance (or list of them)",
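As promised in the review thread above, a minimal sketch of the mypy behaviour behind cloned_outputs_list (illustrative names and types, not pytensor code):

    # Once a name is annotated with a union, mypy refuses a second, narrower
    # annotation on the same name (no-redef), so the list is built in a
    # separately annotated local and assigned to the declared name once.
    def bump(xs: list[int] | int) -> list[int] | int:
        out: list[int] | int
        if isinstance(xs, list):
            # out: list[int] = []   # a second annotation here is rejected by mypy (no-redef)
            out_list: list[int] = []
            for x in xs:
                out_list.append(x + 1)
            out = out_list
        else:
            out = xs + 1
        return out

    print(bump([1, 2, 3]))   # [2, 3, 4]
    print(bump(7))           # 8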
pytensor/compile/function/types.py (4 changes: 2 additions & 2 deletions)

@@ -1637,7 +1637,7 @@ def __init__(
         if any(self.refeed):
             warnings.warn("Inputs with default values are deprecated.", FutureWarning)
 
-    def create(self, input_storage=None, storage_map=None):
+    def create(self, input_storage=None, storage_map=None) -> Function:
         """
         Create a function.
 
@@ -1730,7 +1730,7 @@ def create(self, input_storage=None, storage_map=None):
             import_time = pytensor.link.c.cmodule.import_time - start_import_time
             self.profile.import_time += import_time
 
-        fn = self.function_builder(
+        fn: Function = self.function_builder(
            _fn,
            _i,
            _o,
pytensor/compile/mode.py (3 changes: 2 additions & 1 deletion)

@@ -306,7 +306,7 @@ def __init__(
         self,
         linker: str | Linker | None = None,
         optimizer: str | RewriteDatabaseQuery = "default",
-        db: RewriteDatabase = None,
+        db: RewriteDatabase | None = None,
     ):
         if linker is None:
             linker = config.linker
@@ -317,6 +317,7 @@
 
         self.__setstate__((linker, optimizer))
 
+        self.optdb: RewriteDatabase
         if db is None:
             global optdb
             self.optdb = optdb
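The first hunk is the usual implicit-Optional cleanup: a parameter that defaults to None has to say so in its annotation under current mypy defaults. A small sketch with a stand-in class (not pytensor's RewriteDatabase):

    # With mypy's default settings (no_implicit_optional), a None default does
    # not silently widen the annotation, so the union with None must be spelled out.
    class Database: ...

    # def connect(db: Database = None): ...     # rejected by mypy: None is not a Database
    def connect(db: Database | None = None):    # explicit "| None" is required
        return db if db is not None else Database()

    print(type(connect()).__name__)   # Database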
pytensor/compile/ops.py (34 changes: 17 additions & 17 deletions)

@@ -44,15 +44,15 @@ class TypeCastingOp(COp):
     __props__: tuple = ()
     _f16_ok: bool = True
 
-    def perform(self, node, inputs, outputs_storage):
-        outputs_storage[0][0] = inputs[0]
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = inputs[0]
 
     def __str__(self):
         return f"{self.__class__.__name__}"
 
-    def c_code(self, node, nodename, inp, out, sub):
-        (iname,) = inp
-        (oname,) = out
+    def c_code(self, node, name, inputs, outputs, sub):
+        (iname,) = inputs
+        (oname,) = outputs
         fail = sub["fail"]
 
         itype = node.inputs[0].type.__class__
@@ -92,8 +92,8 @@ def make_node(self, x):
     def infer_shape(self, fgraph, node, input_shapes):
         return input_shapes
 
-    def grad(self, args, g_outs):
-        return g_outs
+    def grad(self, inputs, output_grads):
+        return output_grads
 
 
 view_op = ViewOp()
@@ -160,15 +160,15 @@ def __init__(self):
     def make_node(self, x):
         return Apply(self, [x], [x.type()])
 
-    def perform(self, node, args, outs):
-        if hasattr(args[0], "copy"):
+    def perform(self, node, inputs, output_storage):
+        if hasattr(inputs[0], "copy"):
             # when args[0] is a an ndarray of 0 dimensions,
             # this return a numpy.dtype and not an ndarray
             # So when the args have a copy attribute we use it
             # as this don't have this problem
-            outs[0][0] = args[0].copy()
+            output_storage[0][0] = inputs[0].copy()
         else:
-            outs[0][0] = copy.deepcopy(args[0])
+            output_storage[0][0] = copy.deepcopy(inputs[0])
 
     def c_code_cache_version(self):
         version = []
@@ -192,9 +192,9 @@ def c_code_cache_version(self):
             version.append(1)
         return tuple(version)
 
-    def c_code(self, node, name, inames, onames, sub):
-        (iname,) = inames
-        (oname,) = onames
+    def c_code(self, node, name, inputs, outputs, sub):
+        (iname,) = inputs
+        (oname,) = outputs
         fail = sub["fail"]
 
         itype = node.inputs[0].type.__class__
@@ -253,13 +253,13 @@ def __hash__(self):
     def __str__(self):
         return f"FromFunctionOp{{{self.__fn.__name__}}}"
 
-    def perform(self, node, inputs, outputs):
+    def perform(self, node, inputs, output_storage):
         outs = self.__fn(*inputs)
         if not isinstance(outs, list | tuple):
             outs = (outs,)
-        assert len(outs) == len(outputs)
+        assert len(outs) == len(output_storage)
         for i in range(len(outs)):
-            outputs[i][0] = outs[i]
+            output_storage[i][0] = outs[i]
 
     def __reduce__(self):
         mod = self.__fn.__module__
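All of the renames in this file converge on output_storage as the third perform argument. As a reminder of the calling convention those methods rely on, here is a standalone mimic (not pytensor internals): the caller passes one single-element cell per output and perform writes each result into slot [0].

    # Standalone mimic of the perform() convention used above: output_storage is
    # a list with one pre-allocated [None] cell per output, filled in place.
    def perform(node, inputs, output_storage):
        output_storage[0][0] = inputs[0] + inputs[1]   # single output: the sum

    storage = [[None]]                 # one cell per output
    perform(node=None, inputs=[2, 3], output_storage=storage)
    print(storage[0][0])               # 5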
pytensor/configdefaults.py (2 changes: 1 addition & 1 deletion)

@@ -101,7 +101,7 @@ def _good_seem_param(seed):
         return True
     try:
         int(seed)
-    except Exception:
+    except ValueError:
         return False
     return True
 
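The narrowed handler above matches the failure mode this validator actually guards against: int() raising ValueError on an unparsable seed string. A tiny sketch of just that check (hypothetical helper name, not the full validator):

    # Sketch of the narrowed exception handling: only the "not a number" case
    # is treated as an invalid seed value.
    def _seed_is_valid(seed):   # hypothetical mirror of the validator's core
        try:
            int(seed)
        except ValueError:
            return False
        return True

    print(_seed_is_valid("42"))      # True
    print(_seed_is_valid("forty"))   # False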