
Releases: chaobrain/brainstate

Version 0.3.0

11 Mar 17:18


This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.

Breaking Changes

  • Python >= 3.11 required: Dropped support for Python 3.10. The requires-python field and classifiers now start at 3.11.
  • JAX >= 0.6.0 required: All dependency groups (cpu, cuda12, cuda13, tpu, testing) now mandate jax>=0.6.0.
  • Unified compilation cache in StatefulFunction: The four separate internal caches (_cached_jaxpr, _cached_out_shapes, _cached_jaxpr_out_tree, _cached_state_trace) have been consolidated into a single _compilation_cache storing _CachedCompilation objects. get_cache_stats() now returns {'compilation_cache': {...}} instead of four individual entries.
  • Immutable CacheKey replaces hashabledict: get_arg_cache_key() now returns a CacheKey (NamedTuple) instead of the mutable hashabledict. Code that directly inspected or constructed cache keys must be updated.
  • Removed internal _make_jaxpr function: The custom tracing implementation has been deleted in favor of using jax.make_jaxpr() directly (available in JAX >= 0.6.0).
  • Removed debug_depth and debug_context from GradientTransform: The depth and context parameters for NaN debugging no longer exist following the debug module rewrite.
  • Removed breakpoint_if function: The conditional breakpoint helper has been removed from brainstate.transform._debug.
  • Removed extend_axis_env_nd from compatible imports: This compatibility shim is no longer exported.
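Switching from the mutable hashabledict to an immutable CacheKey matters because a key whose contents change after insertion can no longer be found by its hash. Below is a minimal sketch of a NamedTuple-based key in plain Python. The field names (static_args, dynamic_structure, dynamic_shapes) and the make_key helper are hypothetical and do not describe brainstate's actual CacheKey layout.

```python
from typing import Any, NamedTuple


class CacheKey(NamedTuple):
    """Hypothetical immutable cache key; field names are illustrative only."""
    static_args: tuple[Any, ...]
    dynamic_structure: str
    dynamic_shapes: tuple[tuple[int, ...], ...]


def make_key(static_args: list[Any], structure: str, shapes: list[list[int]]) -> CacheKey:
    # Convert every container to a tuple so the key is hashable and cannot change later.
    return CacheKey(tuple(static_args), structure, tuple(tuple(s) for s in shapes))


cache: dict[CacheKey, str] = {}
key = make_key([1, "relu"], "(*, *)", [[3, 4], [4]])
cache[key] = "compiled program"

# An equal key built from fresh containers hits the same entry.
same = make_key([1, "relu"], "(*, *)", [[3, 4], [4]])
print(cache[same])

# NamedTuple fields cannot be reassigned, unlike the entries of a mutable dict.
try:
    key.static_args = (2,)
except AttributeError:
    print("CacheKey is immutable")
```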

New Features

On-Device NaN/Inf Detection

  • Complete rewrite of the NaN debugging system (brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance.
  • Uses jax.debug.callback with thread-local storage to collect and report NaN findings.
  • Error tracebacks now point to the user's source code via source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations.
  • Recursive instrumentation of nested primitives (jit, cond, while, scan) for comprehensive NaN detection throughout the computation graph.
  • More compact and informative error messages via _format_nan_message().
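According to these notes, findings are reported through jax.debug.callback into thread-local storage. The sketch below shows only the host-side collection pattern using the standard library: the JAX wiring is omitted, and the callback name, its arguments, and the message format are hypothetical. The device-side "any element is non-finite" reduction is simulated here with plain Python.

```python
import math
import threading

# Per-thread storage, so concurrent traces do not mix their findings.
_local = threading.local()


def _findings() -> list[str]:
    if not hasattr(_local, "findings"):
        _local.findings = []
    return _local.findings


def nan_callback(has_bad: bool, primitive: str, location: str) -> None:
    """Hypothetical host callback: records a finding only when the device-side check fired."""
    if has_bad:
        _findings().append(f"{primitive} produced NaN/Inf at {location}")


def report_and_clear() -> list[str]:
    found = list(_findings())
    _findings().clear()
    return found


# Simulated check results; "model.py:42" is a made-up source location.
values = [1.0, float("nan"), 2.0]
nan_callback(any(not math.isfinite(v) for v in values), "log", "model.py:42")
nan_callback(False, "mul", "model.py:43")
print(report_and_clear())
```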

JAX Traceback Filtering

  • Registered brainstate with JAX's traceback_util.register_exclusion() so internal frames are hidden in user-facing error tracebacks. Follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
  • Users can still see full tracebacks via JAX_TRACEBACK_FILTERING=off.
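The same switch can also be set from Python, which is convenient in notebooks. The following is a minimal sketch. It assumes the variable is set before JAX is first imported, because JAX reads environment-based defaults when its configuration is initialised.

```python
import os

# Set before importing JAX so the default is picked up at configuration time.
os.environ["JAX_TRACEBACK_FILTERING"] = "off"

# import jax, brainstate  # later tracebacks would then include library-internal frames
```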

State Validation at Call Time

  • New _validate_state_shapes() method checks that current state shapes and dtypes match those recorded at compile time.
  • StatefulFunction.__call__() automatically validates before execution, catching state shape mismatches early with clear error messages.
  • Added static_argnums bounds validation — make_jaxpr() now raises ValueError if indices exceed the number of positional arguments.
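Both checks can be sketched in plain Python. Here states are represented by a hypothetical (shape, dtype-name) pair instead of real arrays. The function names echo the notes, but the bodies are illustrative rather than brainstate's code. Unlike the notes, which only mention indices that are too large, this sketch also accepts negative indices (counted from the end).

```python
from typing import Any, Sequence

Spec = tuple[tuple[int, ...], str]  # (shape, dtype name)


def validate_state_shapes(recorded: dict[str, Spec], current: dict[str, Spec]) -> None:
    """Hypothetical check: compare current state specs against the compile-time specs."""
    for name, (shape, dtype) in recorded.items():
        now_shape, now_dtype = current[name]
        if now_shape != shape or now_dtype != dtype:
            raise ValueError(
                f"State {name!r} changed since compilation: "
                f"expected {shape} {dtype}, got {now_shape} {now_dtype}"
            )


def check_static_argnums(static_argnums: Sequence[int], args: tuple[Any, ...]) -> None:
    """Hypothetical bounds check raising ValueError for out-of-range indices."""
    for i in static_argnums:
        if not -len(args) <= i < len(args):
            raise ValueError(
                f"static_argnums index {i} is out of range for {len(args)} positional arguments"
            )


recorded = {"w": ((3, 4), "float32")}
validate_state_shapes(recorded, {"w": ((3, 4), "float32")})  # passes silently
try:
    validate_state_shapes(recorded, {"w": ((3, 5), "float32")})
except ValueError as err:
    print(err)
```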

New Compatible Import

  • Added mapped_aval import with version-based routing: jax.core.mapped_aval for JAX < 0.8.2, jax.extend.core.mapped_aval for >= 0.8.2.
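Version-based routing of this kind comes down to comparing a parsed version tuple. In the sketch below, the module paths and the 0.8.2 threshold come from the note above. The parse_version and mapped_aval_module helpers are hypothetical, and brainstate's real compatibility layer may detect versions differently.

```python
def parse_version(version: str) -> tuple[int, ...]:
    """Leading numeric components only, e.g. "0.8.2.dev1" -> (0, 8, 2)."""
    parts: list[int] = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def mapped_aval_module(jax_version: str) -> str:
    """Name of the module expected to provide mapped_aval for this JAX version."""
    if parse_version(jax_version) >= (0, 8, 2):
        return "jax.extend.core"
    return "jax.core"


# With JAX installed, the helper would be resolved roughly like this:
#   import importlib, jax
#   mapped_aval = importlib.import_module(mapped_aval_module(jax.__version__)).mapped_aval
```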

Improvements

  • Atomic cache writes: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
  • Better cache key hashing: Dynamic args/kwargs are now flattened via jax.tree.flatten() before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., Quantity).
  • Modern Python type annotations: Migrated from typing.Tuple, typing.List, typing.Dict, typing.Optional, typing.Union to built-in tuple, list, dict, X | None, X | Y syntax across the codebase.
  • IR visualization compatibility: Replaced direct jax.core.X references with compatible imports (Var, ClosedJaxpr, Jaxpr, JaxprEqn, Literal, DropVar) in the IR visualizer.
  • Deterministic error reporting: jax.debug.callback in _error_if.py now uses ordered=True for deterministic error callback ordering.
  • Graph operations cleanup: Major refactoring of _operation.py, _node.py, _convert.py, and _context.py with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
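The atomic-write and double-checked-locking pattern can be sketched in plain Python as follows. CompilationCache and get_or_compile are hypothetical names; the compile step is any callable and stands in for tracing and lowering.

```python
import threading
from typing import Callable, Hashable


class CompilationCache:
    """Hypothetical cache illustrating double-checked locking with atomic writes."""

    def __init__(self) -> None:
        self._entries: dict[Hashable, object] = {}
        self._lock = threading.Lock()

    def get_or_compile(self, key: Hashable, compile_fn: Callable[[], object]) -> object:
        # First check without the lock: the fast path once an entry exists.
        entry = self._entries.get(key)
        if entry is not None:
            return entry
        with self._lock:
            # Second check under the lock: another thread may have finished first.
            entry = self._entries.get(key)
            if entry is None:
                # If compile_fn raises, nothing is stored, so no partial entry remains.
                entry = compile_fn()
                self._entries[key] = entry
            return entry
```

Holding one lock during compilation serializes compilations of different keys as well. A per-key lock would avoid that, at the cost of more bookkeeping.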

Bug Fixes

  • Fixed Delay.__init__ initialization order: update_every is now initialized before register_entry is called, preventing attribute errors during entry registration (#135).
  • Fixed graph_to_tree private attribute access: Replaced internal _mapping access with public API usage in _convert.py.
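The Delay fix is an initialization-order problem: a method called from __init__ reads an attribute that has not been assigned yet. Below is a heavily simplified, hypothetical stand-in (not brainstate's Delay class). The body of register_entry is illustrative only.

```python
class Delay:
    """Hypothetical, heavily simplified stand-in; not brainstate's Delay class."""

    def __init__(self, update_every: int, entries: dict[str, int]) -> None:
        # Correct order: register_entry reads self.update_every, so assign it first.
        # Moving this assignment below the loop makes the first register_entry
        # call raise AttributeError.
        self.update_every = update_every
        self._entries: dict[str, int] = {}
        for name, delay_steps in entries.items():
            self.register_entry(name, delay_steps)

    def register_entry(self, name: str, delay_steps: int) -> None:
        # Illustrative only: express the delay as a number of buffer updates, rounded up.
        self._entries[name] = -(-delay_steps // self.update_every)
```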

Internal Changes

  • Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
  • Cleaned up TypeVar usage: removed unused C and Names aliases, renamed Node TypeVar to N, removed Hashable bound from type variables.
  • Removed unused tests (test_all_exports, test_function_imports_availability) from compatible import tests.
  • Rewrote debug and make_jaxpr test suites to match the new APIs.
  • IR optimization imports are now lazy-loaded inside make_jaxpr() only when ir_optimizations is configured.

CI/CD

  • Bumped actions/upload-artifact from v6 to v7.
  • Bumped actions/download-artifact from v7 to v8.

What's Changed

  • fix(nn): initialize update_every before register_entry by @Routhleck in #135
  • deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
  • deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
  • Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
  • Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
  • Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139

Full Changelog: v0.2.10...v0.3.0

Version 0.2.10

30 Jan 13:00
2019cae


This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.

New Features

NaN Debugging System

  • JIT-Compatible NaN/Inf Debugging: New debugging utilities for identifying NaN and Inf values during gradient computations

    • debug_nan: Analyze a function for NaN/Inf values with detailed reporting
    • debug_nan_if: Conditional NaN debugging with predicate-based activation
    • Full JIT compatibility for seamless integration into compiled workflows
    • Support for debugging NaN in while and scan primitives
    • Detailed analysis output including variable names, shapes, and affected indices
  • Gradient Function Integration: Added debug_nan parameter to gradient transformation functions

    • grad: Enable NaN debugging during gradient computation
    • vector_grad: NaN debugging for vectorized gradients
    • jacobian and jacobian_reverse: NaN debugging for Jacobian computations
    • hessian: NaN debugging for Hessian computations
  • Breakpoint Utility: New breakpoint function for conditional debugging

    • Wraps jax.debug.breakpoint with predicate support
    • Only triggers when the specified condition is True

API Changes

Module System

  • Renamed ModuleMapper to Map: Simplified naming for the vectorized module wrapper

    • Map provides vectorized (vmap2) and parallel (pmap2) mapping over modules
    • ModuleMapper retained as a deprecated alias for backward compatibility
    • Internal _ModuleMapperCalling renamed to _MapCaller for consistency
  • Enhanced Map.map() Method: Now accepts callable functions for flexible mapping operations
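One common way to keep a renamed class usable as a deprecated alias is a thin subclass that warns on construction. This is a generic sketch: the Map shown here is a stand-in with none of the real vmap2/pmap2 behaviour. The notes do not say whether brainstate's ModuleMapper warns or is a plain assignment (ModuleMapper = Map).

```python
import warnings


class Map:
    """Stand-in for the renamed class; the real Map maps modules with vmap2/pmap2."""

    def __init__(self, module: object) -> None:
        self.module = module


class ModuleMapper(Map):
    """Deprecated alias kept so existing code keeps working."""

    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "ModuleMapper is deprecated; use Map instead.",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not at this __init__
        )
        super().__init__(*args, **kwargs)
```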

Bug Fixes

  • Fixed get_backend import for JAX version compatibility across different JAX releases
  • Removed abstractmethod decorators from Regularization class to allow proper instantiation
  • Cleaned up unused imports in module initialization files

Internal Changes

  • Added comprehensive test suite for NaN debugging (_debug_test.py, 938 lines)
  • Removed deprecated _mapping3.py module and associated tests
  • Streamlined module exports in __init__.py files

Version 0.2.9

16 Jan 14:16
c509d95


This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.

State Management

State Hook System

  • Global Hook Infrastructure: Comprehensive hook system for intercepting state operations

    • register_read_hook: Register hooks that execute when state values are read
    • register_write_hook: Register hooks that execute when state values are written
    • register_restore_hook: Register hooks that execute when state values are restored
    • HookManager: Thread-safe manager for organizing and executing hooks with priority support
    • HookContext: Context manager for scoped hook registration and execution
    • Enables advanced use cases: logging, debugging, value transformation, validation
  • Enhanced State Class: Improved state management with hook integration

    • Automatic hook execution on read/write operations
    • Better cache key handling for improved performance
    • Enhanced thread safety and context management
    • Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
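The sketch below shows a priority-ordered, thread-safe hook registry attached to state reads and writes. HookManager here is a hypothetical re-implementation, and register/run are illustrative method names. The library's actual entry points are the functions listed above (register_read_hook, register_write_hook, and so on), whose signatures are not given in these notes.

```python
import threading
from typing import Any, Callable

Hook = Callable[[str, Any], Any]


class HookManager:
    """Hypothetical priority-ordered hook registry; not brainstate's actual API."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._hooks: list[tuple[int, int, Hook]] = []
        self._counter = 0  # tie-breaker: equal priorities run in registration order

    def register(self, hook: Hook, priority: int = 0) -> None:
        with self._lock:
            self._hooks.append((-priority, self._counter, hook))  # higher priority first
            self._counter += 1
            self._hooks.sort(key=lambda item: (item[0], item[1]))

    def run(self, name: str, value: Any) -> Any:
        with self._lock:
            hooks = [h for _, _, h in self._hooks]  # snapshot; hooks run outside the lock
        for hook in hooks:
            value = hook(name, value)
        return value


class State:
    """Minimal state whose reads and writes pass through global hook managers."""

    read_hooks = HookManager()
    write_hooks = HookManager()

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self._value = value

    @property
    def value(self) -> Any:
        return self.read_hooks.run(self.name, self._value)

    @value.setter
    def value(self, new: Any) -> None:
        self._value = self.write_hooks.run(self.name, new)


log: list[tuple[str, Any]] = []


def log_write(name: str, value: Any) -> Any:
    log.append((name, value))
    return value


def clamp_non_negative(name: str, value: Any) -> Any:
    return max(value, 0.0)


State.write_hooks.register(log_write)                        # priority 0
State.write_hooks.register(clamp_non_negative, priority=10)  # runs first

w = State("w", 1.0)
w.value = -3.0
print(w.value, log)  # 0.0 [('w', 0.0)]
```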

Neural Network Components

Parameter Management (brainstate.nn.Param and brainstate.nn.Const)

  • Renamed Classes: Simplified naming convention

    • ParaM → Param: Trainable parameter wrapper
    • ConstM → Const: Non-trainable constant wrapper
  • Enhanced Caching System: Improved parameter precomputation and caching

    • param_precompute context manager for efficient parameter transformation caching
    • cache() method for retrieving cached parameter values
    • Support for custom precompute functions
    • Automatic cache invalidation and management
    • 391 comprehensive tests for caching behavior
  • Hierarchical Parameter Data (brainstate.nn.HiData): New module for structured parameter organization

    • define_param_data() method for declaring hierarchical parameter structures
    • Support for nested parameter groups
    • Improved parameter surgery and manipulation
    • Enhanced type hints and documentation

Module System Enhancements

  • ModuleMapper: New helper for vectorized module operations (formerly Vmap2Module)

    • Simplified API for applying vmap2 to module methods
    • Automatic state management for vectorized operations
    • Consistent interface with Vmap2ModuleCaller
    • Comprehensive documentation with usage examples
  • Enhanced Module Methods:

    • parameters(): Iterate over all parameters in the module hierarchy
    • named_parameters(): Iterate over parameters with their qualified names
    • children(): Access direct child modules
    • named_children(): Access child modules with names
    • init_all_states(): Initialize states with additional keyword arguments
    • Improved Sequential with extend() and insert() methods

Delay Mechanisms

  • Frequency-Controlled Updates: Enhanced Delay class with flexible update strategies

    • update_every parameter: Control how often delay buffers are updated
    • Support for integer steps (update every N steps)
    • Support for time-based updates with physical units (e.g., 1*ms)
    • Automatic handling of unit conversions and validation
    • Comprehensive tests covering various update strategies
  • Unified Delay Implementation: Refactored delay mechanism

    • Ring buffer implementation for efficient historical value storage
    • Support for linear interpolation
    • Better handling of multi-dimensional inputs
    • Improved integration with neural network modules
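A ring buffer with update_every and linear interpolation can be sketched in plain Python, using float milliseconds in place of brainunit quantities such as 1*ms. RingDelay, retrieve, and steps_from_time are hypothetical names, not brainstate's Delay API. In this sketch a delay counts stored entries, so with update_every > 1 each entry spans update_every simulation steps.

```python
class RingDelay:
    """Hypothetical ring-buffer delay line; not brainstate's Delay implementation."""

    def __init__(self, max_delay_steps: int, update_every: int = 1, init: float = 0.0) -> None:
        self.size = max_delay_steps + 1
        self.buffer = [init] * self.size
        self.head = 0  # index of the most recently stored value
        self.update_every = update_every
        self.step = 0

    def update(self, value: float) -> None:
        # Write only on every `update_every`-th call; other calls just advance the counter.
        if self.step % self.update_every == 0:
            self.head = (self.head + 1) % self.size
            self.buffer[self.head] = value
        self.step += 1

    def retrieve(self, delay: float) -> float:
        """Value stored `delay` entries ago, linearly interpolated for fractional delays."""
        if not 0 <= delay <= self.size - 1:
            raise ValueError(f"delay must be in [0, {self.size - 1}]")
        lo = int(delay)
        frac = delay - lo
        v_lo = self.buffer[(self.head - lo) % self.size]
        if frac == 0:
            return v_lo
        v_hi = self.buffer[(self.head - lo - 1) % self.size]
        return (1 - frac) * v_lo + frac * v_hi


def steps_from_time(update_every_ms: float, dt_ms: float) -> int:
    """Convert a time-based update interval into a whole number of simulation steps."""
    steps = round(update_every_ms / dt_ms)
    if steps < 1 or abs(steps * dt_ms - update_every_ms) > 1e-9 * max(1.0, update_every_ms):
        raise ValueError("update_every must be a positive multiple of dt")
    return steps
```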

Regularization

  • Comprehensive Regularization Module (brainstate.nn._regularization, 2840 lines):

    • Complete suite of regularization techniques
    • L1, L2, and elastic net regularization
    • Dropout variants
    • Weight decay and other parameter constraints
    • 1261 tests for regularization functionality
  • Transform Module (brainstate.nn._transform, 1661 lines):

    • Advanced parameter transformations
    • Quantization support
    • Normalization techniques
    • Integration with caching system
    • 452 comprehensive tests
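For reference, the standard textbook forms of the penalties named above are shown below, with strength $\lambda$ and elastic-net mixing ratio $\alpha \in [0, 1]$. brainstate's scaling conventions, for example whether the L2 term carries the factor $\tfrac{1}{2}$, may differ.

```latex
R_{\text{L1}}(w) = \lambda \sum_i |w_i|, \qquad
R_{\text{L2}}(w) = \frac{\lambda}{2} \sum_i w_i^2, \qquad
R_{\text{EN}}(w) = \lambda \left( \alpha \sum_i |w_i| + \frac{1 - \alpha}{2} \sum_i w_i^2 \right)
```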

Transformations

Vectorization and Parallelization

  • Mapping Function Refactoring: Reorganized mapping implementations

    • Renamed _mapping.py → _mapping2.py (primary vmap2 implementation)
    • Renamed _mapping_old.py → _mapping1.py (legacy vmap implementation)
    • Added _mapping3.py: New pmap2 implementation for parallelization
    • vmap2_new_states: Helper for creating new states in vectorized operations
    • Relaxed return type requirements for more flexible mapping functions
  • Enhanced Documentation: Updated tutorials and API documentation

    • Comprehensive vmap2 tutorial with practical examples
    • Enhanced parallelization documentation for pmap2
    • Updated state management guides
    • Expanded gradient transformation documentation

Compatibility and Utilities

JAX Compatibility

  • Enhanced JAX Integration: Improved compatibility with newer JAX versions
    • Updated backend import for JAX version detection
    • Enhanced get_aval function for JAX version compatibility
    • Standardized jit_named_scope arguments
    • Support for JAX 0.8.0+ in CI configuration

Utility Functions

  • Dataclass Support: Added is_dataclass utility function in brainstate.util.struct

    • Robust dataclass type checking
    • Better handling of dataclass-based structures
  • Tracer Utilities: New _tracers.py module for JAX tracer handling

    • current_jax_trace(): Get current JAX trace context with version compatibility
    • Helper functions for working with JAX abstract values

Graph Operations

  • Context Management (brainstate.graph._context):

    • New context management system for graph operations (119 lines)
    • TraceContextError: Specialized error class for tracing issues
    • Enhanced state tracking during graph construction
    • 64 tests for context management
  • Conversion Utilities (brainstate.graph._convert):

    • New conversion utilities for graph operations (278 lines)
    • Better handling of graph transformations
    • Improved node conversion logic

Random Number Generation

  • Enhanced RandomState: Improved random number generation
    • Better compatibility with newer JAX versions (98 lines of improvements)
    • Enhanced state management for random keys
    • Improved thread safety
    • Better error messages and validation

Documentation

  • Comprehensive API Documentation: Expanded documentation across all modules

    • brainstate.rst: Reorganized with improved structure (21 lines removed, refactored into submodules)
    • environ.rst: Added 48 lines of documentation for environment state and keys
    • nn.rst: Added 222 lines documenting neural network components
    • transform.rst: Added 132 lines for gradient transformations and mapping functions
  • Tutorial Updates:

    • Updated vectorization tutorial to reflect vmap → vmap2 transition
    • Enhanced examples with ModuleMapper usage
    • Improved state management examples

Breaking Changes

  • Renamed Functions and Classes:

    • ParaM → Param
    • ConstM → Const
    • vmap → vmap2 (old vmap preserved in _mapping1.py for compatibility)
    • pmap → pmap2
    • _param_data → _hidata
  • Parameter Naming Standardization:

    • fit_par → fit across all modules
    • brainscale → braintrace in example files
  • Method Signature Changes:

    • init_all_states() now accepts additional keyword arguments
    • param_precompute() signature updated to support caching and custom functions
    • Module initialization methods enhanced with keyword argument support

Testing

  • Comprehensive Test Coverage: Added 4,000+ lines of new tests
    • Thread safety tests: 346 tests ensuring thread-safe operations
    • Hook system tests: 320 tests for state hooks
    • State management tests: 924 tests expanded coverage
    • Parameter caching tests: 391 tests for caching behavior
    • Delay mechanism tests: 244 tests for delay functionality
    • HiData tests: 463 tests for hierarchical data structures
    • Module tests: 661 tests expanded coverage
    • Regularization tests: 1,261 tests
    • Transform tests: 452 tests
    • Mapping tests: Updated for vmap2 and pmap2

Bug Fixes

  • Fixed cache key handling in state management
  • Improved error messages for missing states in gradient transformations
  • Enhanced validation for delay update frequency
  • Corrected import paths for better module organization
  • Fixed compatibility issues with JAX 0.8.0+

Internal Changes

  • Reorganized import statements across all modules for clarity
  • Enhanced type hints throughout the codebase
  • Improved code documentation with comprehensive docstrings
  • Streamlined module exports in __all__ definitions
  • Better separation of concerns in module organization

What's Changed

  • Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
  • Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
  • Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
  • Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
  • Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130

Full Changelog: v0.2.8...v0.2.9

Version 0.2.8

19 Dec 06:21
ac51f5f


This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.

Compatibility

  • JAX 0.8.2+ Support: Added compatibility with JAX version 0.8.2 and later. The library now uses jax.make_jaxpr directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.

Breaking Changes

  • Removed abstracted_axes parameter: The abstracted_axes parameter has been removed from:
    • StatefulFunction.__init__
    • StatefulMapping.__init__
    • make_jaxpr function
    • _make_jaxpr internal function

Improvements

  • Debug mode support: Added debug_call method to StatefulFunction for proper execution when jax.config.jax_disable_jit is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.

  • Lazy loading optimization: RandomState import in the _mapping module is now lazily loaded via _import_rand_state(), improving initial import performance and reducing circular dependency issues.

Internal Changes

  • Removed unused imports (annotate, api_boundary from jax._src) at module level; now imported only where needed
  • Removed internal helper functions _broadcast_prefix and _flat_axes_specs
  • Simplified _abstractify function by removing abstracted axes handling
  • Updated example files to reflect API changes

What's Changed

Full Changelog: v0.2.7...v0.2.8

Version 0.2.7

15 Dec 12:23
cdfe726


BrainState 0.2.7 modernizes the experimental compilation stack, deepens the transformation APIs, and tightens runtime infrastructure across the project.

Experimental Compiler and Visualization

  • Introduced the experimental neuroir compiler built on dataclass-based graph IR elements and an explicit CompilationContext, improving dependency tracking, hidden-state mapping, and ClosedJaxpr fidelity even for self-connections and delay buffers.
  • Added GraphDisplayer and TextDisplayer backends with hierarchical and force-directed layouts, plus richer diagnostics and tests that cover large sample networks and neuro-graph visualizations.

Transformations and Autodiff

  • Added the jit_named_scope decorator and supporting utilities so nested transformations emit meaningful names inside traced functions, together with _make_jaxpr refinements that separate dynamic/static arguments and improve caching semantics for StatefulFunction.
  • Expanded the gradient toolkit by exporting the new Jacobian (forward and reverse), Hessian, and SOFO transforms, unifying gradient handling for classes, auxiliary returns, and state-aware updates through the transform module.

State and Runtime Enhancements

  • Replaced the experimental ArrayParam with a dedicated DelayState, propagating the new state through the compiler, delay modules, and neuro-IR so historical buffers participate in tracing and optimization just like other states.
  • Environment helpers can now run against injected EnvironmentState instances, enabling sandboxed or per-thread configurations while DelayState-aware unit tests extend coverage of the updated modules.

Experimental and Infrastructure Updates

  • Completed the neuron IR → neuroir rename, aligned the GDiist BPU codebase with the new terminology, and added new sample networks plus placeholder skips to keep the growing compiler/displayer test surface manageable.
  • Added braincell to the development requirements, refreshed documentation wording, and kept CI dependencies current for the GitHub Actions runners.

Bug Fixes

  • Hardened caching, randomness, and initialization logic by fixing get_arg_cache_key, removing stale decorator parameters, validating truncated normal draws, and correcting the exported version metadata.
  • Declared Python 3.14 support and cleaned up compiler import ordering to keep linting noise low.

What's Changed

Full Changelog: v0.2.4...v0.2.7

Version 0.2.4

31 Oct 14:01


This release introduces the new ArrayParam state type for parameter arrays with custom transformations, experimental BPU backend export support, enhanced JAXPR optimization capabilities, and improved module organization.

New Features

ArrayParam State Type

  • ArrayParam Class: New state type for managing parameter arrays with advanced transformation control
    • Supports custom transformations (e.g., quantization, normalization) that preserve array identity
    • Enables vmap, pmap, and other JAX transformations to correctly handle stateful parameters
    • Provides identity() method that returns the raw array without applying custom transformations
    • Integrates seamlessly with existing State management infrastructure
    • Useful for implementing quantization-aware training and other advanced parameter manipulations
    • Comprehensive documentation with usage examples and best practices
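The split between a transformed read and identity() can be sketched with a toy quantizer. To stay dependency-free, this uses plain lists instead of arrays. TransformedParam and fake_quantize are hypothetical names; only the identity() idea comes from the notes above.

```python
from typing import Callable


class TransformedParam:
    """Hypothetical sketch: `value` applies a transform, `identity()` returns the raw data."""

    def __init__(self, raw: list[float], transform: Callable[[list[float]], list[float]]) -> None:
        self._raw = raw
        self._transform = transform

    @property
    def value(self) -> list[float]:
        return self._transform(self._raw)

    def identity(self) -> list[float]:
        return self._raw


def fake_quantize(xs: list[float], step: float = 0.25) -> list[float]:
    # Round each weight to the nearest multiple of `step` (toy quantization).
    return [round(x / step) * step for x in xs]


p = TransformedParam([0.1, 0.37, -0.6], fake_quantize)
print(p.value)       # [0.0, 0.25, -0.5]
print(p.identity())  # [0.1, 0.37, -0.6]
```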

Experimental BPU Backend Export (brainstate.experimental.gdiist_bpu)

  • BPU Backend Export Support: Complete infrastructure for exporting models to GDiist BPU hardware backend (727 lines)

    • export.py: Main export API with to_bpu() function for model conversion
    • parser.py: Operation parser that analyzes JAXPR to identify operations and connections (305 lines)
    • data.py: Data structures and analysis utilities for operation representation (215 lines)
  • Operation Parser Features:

    • Automatic detection of operations from JAXPR equations using brainevent primitives
    • Data flow analysis to identify connections between operations
    • Support for various operation types: slice, add, multiply, and more
    • Detailed analysis output showing equations, inputs, outputs, and connections
  • Analysis and Debugging Tools:

    • display_analysis_results(): Comprehensive visualization of parsed operations
    • Shows operation details including equation count, variable mappings, and connections
    • Displays connection information with producer/consumer operations and variable details
    • Example implementation in examples/400_CUBA_2005_bpu.py

Enhancements

JAXPR Optimization Improvements

  • Enhanced Constant Folding:

    • Better handling of literal values in constant folding optimization
    • Improved detection and elimination of redundant literal operations
    • More efficient constant propagation through computation graphs
  • Identity Equation Optimization:

    • Optimized handling of Literal outputs to avoid unnecessary bridging equations
    • Improved identity equation creation for interface preservation
    • Better handling of edge cases in optimization passes
  • Error Handling:

    • Added fallback source info utility for better error messages
    • Fixed potential NoneType errors in equation handling
    • Improved validation of optimization results

State Management

  • Enhanced State Tests: Comprehensive test refactoring with improved coverage (454 tests)
    • Better organization of state type tests
    • More thorough validation of state behavior
    • Enhanced test readability and maintainability

What's Changed

Full Changelog: v0.2.3...v0.2.4

Version 0.2.3

13 Oct 14:47
2235c36


This release introduces powerful IR (Intermediate Representation) optimization capabilities for JAX computation graphs, comprehensive state management refactoring for vectorized mapping operations, and extensive testing infrastructure improvements.

New Features

IR Optimization (brainstate.transform._ir_optim)

  • Intermediate Representation Optimization Module (876 lines): Complete suite of compiler-level optimizations for JAX computation graphs

    • constant_fold: Evaluates constant expressions at compile time, reducing runtime computation
    • dead_code_elimination: Removes equations whose outputs are unused, reducing computation overhead
    • common_subexpression_elimination: Identifies and reuses results of identical computations
    • copy_propagation: Eliminates unnecessary copy operations by propagating original variables
    • algebraic_simplification: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)
    • optimize_jaxpr: Orchestrates multiple optimization passes with configurable iteration and verbose mode
  • IdentitySet Class: Custom set implementation using object identity (id()) instead of equality

    • Enables proper handling of JAX variables and Literals in optimization passes
    • Implements MutableSet interface with full collection protocol support
    • Essential for tracking variable usage without relying on equality comparisons
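An identity-based set can be built on collections.abc.MutableSet by keying on id(). The following is a minimal sketch, not brainstate's implementation. The Var class is a hypothetical stand-in for an IR variable whose __eq__ compares by name.

```python
from collections.abc import Iterable, Iterator, MutableSet
from typing import Any


class IdentitySet(MutableSet):
    """Set keyed on id(), so members compare by identity rather than __eq__."""

    def __init__(self, items: Iterable[Any] = ()) -> None:
        self._items: dict[int, Any] = {}
        for item in items:
            self.add(item)

    def __contains__(self, item: object) -> bool:
        return id(item) in self._items

    def __iter__(self) -> Iterator[Any]:
        return iter(self._items.values())

    def __len__(self) -> int:
        return len(self._items)

    def add(self, item: Any) -> None:
        # Storing the object keeps it alive, so its id() cannot be reused while present.
        self._items[id(item)] = item

    def discard(self, item: Any) -> None:
        self._items.pop(id(item), None)


class Var:
    """Hypothetical IR variable whose equality is by name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Var) and self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)


a, b = Var("x"), Var("x")  # equal, but distinct objects
s = IdentitySet([a])
print(a == b, a in s, b in s)  # True True False
```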

Optimization Features

  • Interface Preservation: All optimizations preserve function input/output variables (invars/outvars)

    • Identity equations automatically added when needed to maintain correct interfaces
    • Uses convert_element_type primitive with matching dtypes as identity operation
    • Ensures optimized functions remain drop-in replacements
  • Optimization Pipeline: Configurable multi-pass optimization with convergence detection

    • Customizable optimization sequence via optimizations parameter
    • Automatic convergence detection when no more reductions possible
    • Maximum iteration control with max_iterations parameter
    • Verbose mode with detailed statistics and progress tracking
  • JAX Integration: Full support for JAX primitives and special cases

    • Blacklist for primitives that shouldn't be folded (broadcast_in_dim, broadcast)
    • Proper handling of closed_call and scan primitives
    • Support for both Jaxpr and ClosedJaxpr inputs
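The pass pipeline can be illustrated on a toy IR of (out, op, args) equations instead of a real Jaxpr. In this sketch, the pass names mirror the ones listed above. Convergence is detected by equation count, and the interface is protected by never eliminating an equation that defines an output (brainstate instead inserts identity equations). Everything here is hypothetical and simplified.

```python
from typing import Callable, NamedTuple


class Eqn(NamedTuple):
    """One toy equation: out = op(*args)."""
    out: str
    op: str
    args: tuple[str, ...]


Pass = Callable[[list[Eqn], list[str]], list[Eqn]]


def dead_code_elimination(eqns: list[Eqn], outvars: list[str]) -> list[Eqn]:
    live = set(outvars)
    kept: list[Eqn] = []
    for eqn in reversed(eqns):  # walk backwards from the outputs
        if eqn.out in live:
            kept.append(eqn)
            live.update(eqn.args)
    return kept[::-1]


def common_subexpression_elimination(eqns: list[Eqn], outvars: list[str]) -> list[Eqn]:
    seen: dict[tuple[str, tuple[str, ...]], str] = {}
    rename: dict[str, str] = {}
    result: list[Eqn] = []
    for eqn in eqns:
        args = tuple(rename.get(a, a) for a in eqn.args)
        key = (eqn.op, args)
        # Equations that define an output are kept so the interface stays intact.
        if key in seen and eqn.out not in outvars:
            rename[eqn.out] = seen[key]  # later uses point at the earlier result
            continue
        seen.setdefault(key, eqn.out)
        result.append(Eqn(eqn.out, eqn.op, args))
    return result


def optimize(eqns: list[Eqn], invars: list[str], outvars: list[str],
             passes: list[Pass], max_iterations: int = 5) -> list[Eqn]:
    for _ in range(max_iterations):
        before = len(eqns)
        for opt_pass in passes:
            eqns = opt_pass(eqns, outvars)
        if len(eqns) == before:  # converged: a full round removed nothing
            break
    # Interface check: every output must still be produced or be an input.
    defined = set(invars) | {e.out for e in eqns}
    missing = [v for v in outvars if v not in defined]
    if missing:
        raise ValueError(f"optimization dropped outputs: {missing}")
    return eqns


program = [
    Eqn("t1", "mul", ("x", "y")),
    Eqn("t2", "mul", ("x", "y")),  # duplicate of t1
    Eqn("t3", "add", ("t1", "t2")),
    Eqn("unused", "sin", ("x",)),  # never reaches an output
]
optimized = optimize(program, ["x", "y"], ["t3"],
                     [dead_code_elimination, common_subexpression_elimination])
print(optimized)
```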

State Management Refactoring (brainstate.transform._mapping)

  • Renamed vmap to vmap2: Major refactoring of vectorized mapping implementation (647 lines)

    • Enhanced state management with improved axis tracking
    • Better error messages and validation
    • Streamlined state value restoration logic
  • Old vmap Implementation Preserved (_mapping_old.py, 579 lines): Legacy vmap with explicit state management

    • Exports original vmap and vmap_new_states functions
    • Maintains backward compatibility for existing code
    • Specialized for stateful functions with explicit state parameters

Documentation

API Documentation

  • transform.rst: Added comprehensive IR Optimization section (24 lines)

    • Detailed module description explaining compiler optimizations
    • All 6 optimization functions documented with autosummary
    • Clear explanation of benefits: reduced computation overhead, improved runtime performance
    • Positioned between Compilation Tools and Gradient Computations sections
  • NumPy-style Docstrings: All optimization functions include:

    • Comprehensive parameter descriptions with types and defaults
    • Detailed return value documentation
    • Notes sections explaining preservation of function interfaces
    • Multiple practical examples demonstrating usage
    • Algorithm descriptions for complex optimizations
    • Cross-references between related functions

Enhancements

Optimization Pipeline

  • Progress Tracking: Verbose mode shows equation count changes after each optimization

    • Displays initial, intermediate, and final equation counts
    • Shows reduction statistics with percentages
    • Indicates convergence detection
    • Reports iteration counts
  • Validation: Runtime checks ensure optimization correctness

    • Verifies input variables unchanged after optimization
    • Validates output variables preserved
    • Raises clear errors if interface violated
    • Checks for valid optimization names
  • Flexibility: Customizable optimization sequences

    • Apply all optimizations in recommended order (default)
    • Select specific optimizations only
    • Control iteration limits
    • Toggle verbose output

JAX Integration

  • JaxprEqn Construction: Proper handling of required ctx parameter

    • Uses JaxprEqnContext(None, True) for identity equations
    • Ensures compatibility with JAX internal API
    • Maintains proper equation structure
  • Primitive Handling: Special cases for JAX primitives

    • Blacklist for primitives that shouldn't be optimized
    • Proper parameter extraction and validation
    • Support for effects and source_info fields

Bug Fixes

  • Fixed JaxprEqn constructor calls to include required ctx parameter (7th positional argument)
  • Corrected import paths for vmap2 in test files and tutorials
  • Fixed RandomState.uniform() calls to use size parameter instead of shape
  • Enhanced test assertions for proper state axis handling
  • Improved error messages for batch axis mismatches

Refactoring

Transform Module

  • Renamed Files:

    • vmap → vmap2 in _mapping.py
    • Preserved original vmap in _mapping_old.py for compatibility
  • Module Exports: Updated __init__.py to export both old and new vmap implementations

    • vmap from _mapping_old.py (legacy)
    • vmap2 from _mapping.py (new)
    • vmap_new_states from both modules

What's Changed

  • Introduce JAXPR optimizations and enhance stateful mapping by @chaoming0625 in #108

Full Changelog: v0.2.2...v0.2.3

Version 0.2.2

12 Oct 15:36
74c0b31


This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.

New Features

Hidden State Classes

  • HiddenGroupState: New class for managing multiple hidden states within a single array

    • Stores multiple states in the last dimension of a single array
    • Provides get_value() and set_value() methods for accessing individual states by index or name
    • Optimized for LSTM-style architectures with multiple hidden components (h, c)
    • Includes name2index mapping for convenient state access
  • HiddenTreeState: New class for managing multiple hidden states with different physical units

    • Supports PyTree structure (dict or sequence) of hidden states
    • Preserves physical units (e.g., voltage, current, conductance) via brainunit integration
    • Provides name2unit and index2unit mappings for unit tracking
    • Ideal for neuroscience models with heterogeneous state variables
    • Maintains compatibility with BrainScale online learning

State Utilities

  • maybe_state: New utility function for flexible value extraction
    • Extracts values from State objects automatically
    • Returns non-State values unchanged
    • Simplifies writing functions that accept both states and raw values
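The behaviour of maybe_state amounts to a simple type dispatch. The sketch below defines a hypothetical minimal State class so it runs without brainstate installed; the real function operates on brainstate.State objects.

```python
from typing import Any


class State:
    """Hypothetical minimal state container (stand-in for brainstate.State)."""

    def __init__(self, value):
        self.value = value


def maybe_state(x: Any) -> Any:
    # Unwrap State objects; pass every other value through unchanged.
    return x.value if isinstance(x, State) else x


def scaled(x, factor=2.0):
    # Works the same whether the caller passes a State or a raw value.
    return maybe_state(x) * factor


print(scaled(State(3.0)), scaled(3.0))  # 6.0 6.0
```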

Enhancements

State Classes

  • HiddenState: Enhanced documentation and type checking

    • Accepts only numpy.ndarray, jax.Array, and brainunit.Quantity values
    • Added comprehensive docstrings with examples
    • Clarified equivalence to brainscale.ETraceState for online learning
    • Improved error messages for invalid input types
  • BatchState: Now properly exported in the public API

    • Available via brainstate.BatchState
    • Enhanced documentation for batch data management

Documentation

  • API Reference: Completely reorganized brainstate.rst documentation

    • Organized into six major sections, including Core State Classes, State Management, State Utilities, Error Handling, and Submodules
    • Added detailed descriptions for each section and subsection
    • Included comprehensive bullet-point summaries for all APIs
    • Enhanced deprecation warnings with clear migration paths
    • Added module-level descriptions for all submodules
  • State Classes: Enhanced documentation for all state types

    • Added detailed use case descriptions
    • Included practical examples for each state type
    • Clarified semantic distinctions between state types
    • Documented integration with JAX transformations
  • JAX Transformations: Improved documentation for stateful transforms

    • Enhanced docstrings for jit, grad, vmap, scan, and other transforms
    • Added examples showing state management patterns
    • Documented state tracing behavior
    • Clarified interaction with StateTraceStack

Transform System

  • Enhanced State Finding: New _find_state.py module for automatic state discovery

    • Improved state detection in nested structures
    • Better handling of state dependencies
    • Enhanced error messages for state-related issues
  • StatefulFunction: Major enhancements to make_jaxpr functionality

    • Improved Jaxpr generation for stateful computations
    • Better handling of state read/write tracking
    • Enhanced debugging support
  • Mapping Transformations: Significant refactoring of vmap and pmap

    • Improved state management across vectorized operations
    • Better handling of state broadcasting
    • Enhanced error reporting for mapping operations
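Automatic state discovery in nested structures is essentially a recursive traversal. The sketch below is a hypothetical illustration, not the contents of _find_state.py: it walks dicts, lists, tuples, and object attributes, records each State it meets together with its path, and tracks object identity so that a state shared by two owners is reported once and reference cycles do not recurse forever.

```python
class State:
    """Hypothetical stand-in for brainstate.State."""

    def __init__(self, value):
        self.value = value


def find_states(tree, path=(), found=None, seen=None):
    """Return {path: state} for every State reachable from `tree`."""
    found = {} if found is None else found
    seen = set() if seen is None else seen
    if id(tree) in seen:
        return found  # already visited: a shared state or a reference cycle
    if isinstance(tree, State):
        seen.add(id(tree))
        found[path] = tree
        return found
    if isinstance(tree, dict):
        children = list(tree.items())
    elif isinstance(tree, (list, tuple)):
        children = list(enumerate(tree))
    elif hasattr(tree, "__dict__"):
        children = list(vars(tree).items())
    else:
        return found  # a leaf that is not a State
    seen.add(id(tree))
    for key, child in children:
        find_states(child, path + (key,), found, seen)
    return found


class Layer:
    def __init__(self, shared):
        self.weight = State([1.0, 2.0])
        self.cache = {"h": State(0.0), "step": 3}
        self.shared = shared


shared = State("shared")
model = [Layer(shared), Layer(shared)]
found = find_states(model)
for p in found:
    print(p)
```

The shared state appears only under its first path, (0, "shared"); deduplicating by identity is one reasonable policy, and a real implementation may choose differently.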

Random Number Generation

  • Module Reorganization: Complete refactoring of random module structure

    • Renamed _rand_funs.py to _fun.py
    • Renamed _rand_seed.py to _seed.py
    • Renamed _rand_state.py to _state.py
    • Extracted distribution implementations to new _impl.py module (691 lines)
  • Improved Random State: Enhanced RandomState class with better state management

    • Simplified implementation (reduced from 534 to ~300 lines)
    • Better integration with JAX's random number generation
    • Improved thread safety and state isolation

Testing

  • Comprehensive Test Suite: Added 102 tests covering state functionality, including:
    • TestBasicState (13 tests): Core State class operations
    • TestShortTermState (2 tests): Short-term state behavior
    • TestLongTermState (2 tests): Long-term state behavior
    • TestParamState (2 tests): Parameter state usage patterns
    • TestBatchState (2 tests): Batch state functionality
    • TestHiddenState (7 tests): Hidden state with different array types
    • TestHiddenGroupState (9 tests): Multiple hidden state management
    • TestHiddenTreeState (12 tests): PyTree hidden states with units
    • TestFakeState (4 tests): Lightweight state alternative
    • TestStateDictManager (6 tests): State collection management
    • TestStateTraceStack (11 tests): State tracing and recovery
    • TestTreefyState (6 tests): PyTree state references
    • TestContextManagers (6 tests): State context managers
    • TestStateCatcher (8 tests): State catching utilities
    • TestIntegrationScenarios (5 tests): Real-world use cases

Bug Fixes

  • Fixed HiddenGroupState.set_value() to work correctly with JAX arrays
  • Improved error handling in hidden state value validation
  • Enhanced type checking for hidden state initialization

Documentation

Tutorial Reorganization

  • Basics Tutorials: Complete rewrite and expansion

    • 01_getting_started.ipynb: Enhanced introduction with practical examples
    • 02_state_management.ipynb: Comprehensive state management guide
    • 03_random_numbers.ipynb: In-depth random number generation tutorial
  • Neural Networks Tutorials: Restructured and expanded

    • 01_module_basics.ipynb: New comprehensive module system guide
    • 02_basic_layers.ipynb: Enhanced layer documentation with examples
    • 03_activations_normalization.ipynb: Detailed activation and normalization guide
    • 04_recurrent_networks.ipynb: New RNN tutorial with practical examples
    • 05_dynamics_systems.ipynb: New dynamical systems tutorial
  • Examples: Reorganized and enhanced

    • Renamed 10_image_classification.ipynb to 01_image_classification.ipynb
    • Renamed 11_sequence_modeling.ipynb to 02_sequence_modeling.ipynb
    • Added 03_brain_inspired_computing.ipynb: New brain-inspired computing examples
    • Renamed 18_optimization_tricks.ipynb to 04_optimization_tricks.ipynb
    • Renamed 19_model_deployment.ipynb to 05_model_deployment.ipynb
  • Transforms Tutorials: Reorganized for better flow

    • 01_jit_compilation.ipynb: New comprehensive JIT guide
    • 02_automatic_differentiation.ipynb: Enhanced autodiff tutorial
    • 03_vectorization.ipynb: Improved vmap/pmap guide
    • 04_loops_conditions.ipynb: Enhanced control flow guide
    • 05_other_transforms.ipynb: Other transformation utilities
  • Advanced Tutorials: Renumbered for clarity

    • 01_graph_operations.ipynb (formerly 14_graph_operations.ipynb)
    • 02_mixin_system.ipynb (formerly 15_mixin_system.ipynb)
    • 03_typing_system.ipynb (formerly 16_typing_system.ipynb)
    • 04_utilities.ipynb (formerly 17_utilities.ipynb)
  • Migration Guides: Updated and simplified

    • 01_migration_from_pytorch.ipynb: Enhanced PyTorch migration guide
    • Removed outdated BrainPy integration notebook
  • Supplementary: Reorganized

    • 01_performance_optimization.ipynb
    • 02_debugging_tips.ipynb
    • 03_faq.ipynb: Updated FAQ with new content

API Documentation

  • Enhanced module documentation in nn.rst (306 lines of improvements)
  • Updated transform.rst with new transform APIs
  • Improved environ.rst and graph.rst documentation

Refactoring

  • Removed deprecated eval_shape module and tests
  • Removed deprecated _random.py transform module
  • Cleaned up unused imports across all modules
  • Improved code organization in neural network layers
  • Enhanced type hints and docstrings throughout

Infrastructure

  • Added development dependency for tutorial generation
  • Updated benchmark scripts for performance testing
  • Improved test coverage across transformation modules

Full Changelog: v0.2.0...v0.2.2

Version 0.2.0

08 Oct 08:35
8d9251f

This is a major release with significant refactoring, new features, and comprehensive documentation improvements.

Breaking Changes

  • Module Deprecations: Deprecated brainstate.augment, brainstate.compile, and brainstate.functional modules in favor of brainstate.transform and brainstate.nn

    • Added deprecation proxies to guide users towards replacement modules
    • Updated all documentation and examples to use new module paths
  • State Management: Replaced write_back_state_values with assign_state_vals_v2 for improved state management

  • Import Path Changes: Major refactoring of import paths across the codebase

    • Moved initialization references to use brainstate.nn
    • Updated random functions to use brainstate.random
    • Standardized imports across all modules
  • Type System: Implemented JointTypes and OneOfTypes generic aliases to enhance type checking and avoid metaclass conflicts

    • Support for subscript syntax
    • Improved type hints across modules
  • Copyright: Updated copyright notices to reflect new ownership by BrainX Ecosystem Limited
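A deprecation proxy can be sketched as a module object whose attribute lookups forward to the replacement module while emitting a DeprecationWarning. The class name and message below are hypothetical; only the general pattern (warn, then forward) reflects the description above. The standard-library math module stands in for the replacement module so the sketch runs anywhere.

```python
import math
import types
import warnings


class DeprecatedModuleProxy(types.ModuleType):
    """Hypothetical proxy: forwards attribute access to `target` with a warning."""

    def __init__(self, old_name: str, target: types.ModuleType):
        super().__init__(old_name)
        self._old_name = old_name
        self._target = target

    def __getattr__(self, attr):
        # Only called when normal lookup fails, i.e. for the forwarded names.
        warnings.warn(
            f"{self._old_name}.{attr} is deprecated; "
            f"use {self._target.__name__}.{attr} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(self._target, attr)


old = DeprecatedModuleProxy("legacy_math", math)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    value = old.sqrt(16.0)
print(value, caught[0].category.__name__)  # 4.0 DeprecationWarning
```

Because old code keeps working while the warning names the new path, this pattern lets users migrate gradually instead of hitting an ImportError.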

New Features

Neural Network Components

  • Transposed Convolution Layers: Complete implementations for upsampling operations

    • ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
    • Support for both channels-first and channels-last data formats via channel_first parameter
    • Configurable stride for controllable upsampling factors
    • Grouped transposed convolution support
    • Automatic padding computation for 'SAME' and 'VALID' modes
  • Convolution Enhancements: Added support for both channels-first and channels-last data formats

    • New channel_first boolean parameter (default: False)
    • PyTorch-compatible format (e.g., [B, C, H, W]) when channel_first=True
    • Default JAX-style format (e.g., [B, H, W, C]) when channel_first=False
  • Padding Layers: Added padding layers for 1D, 2D, and 3D tensors with various modes

  • Unpooling Layers: Added MaxUnpool1d, MaxUnpool2d, and MaxUnpool3d with return_indices support

  • Gradient Utilities: Implemented clip_grad_norm function for gradient clipping in PyTree structures

  • Embedding Enhancements:

    • Added padding_idx, max_norm, and norm_type parameters
    • Improved gradient management with new _contains_tracer function
    • Optimized max_norm application by using a mask of accessed indices for scaling
  • BatchNorm Improvements: Added feature_axis and track_running_stats parameters

  • LoRA Layer: Added in_size parameter for improved size handling

  • Activation Functions: Added new activation functions and improved signatures
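Gradient clipping by global norm over a PyTree can be sketched as follows. This is a hypothetical NumPy illustration and does not reproduce brainstate's clip_grad_norm signature: nested dicts, lists, and tuples of arrays play the role of the tree, one L2 norm is computed across all leaves, and every leaf is rescaled by the same factor when that norm exceeds max_norm.

```python
import numpy as np


def tree_leaves(tree):
    # Flatten nested dicts/lists/tuples into a list of array leaves.
    if isinstance(tree, dict):
        return [leaf for k in tree for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [np.asarray(tree)]


def tree_map(fn, tree):
    # Apply fn to every leaf, preserving the container structure.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(np.asarray(tree))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Hypothetical helper: scale all leaves so their joint L2 norm is <= max_norm."""
    total = np.sqrt(sum(np.sum(g.astype(np.float64) ** 2) for g in tree_leaves(grads)))
    # eps guards against division by zero when all gradients vanish.
    scale = min(1.0, max_norm / (total + eps))
    return tree_map(lambda g: g * scale, grads), total


grads = {"w": np.array([3.0, 0.0]), "b": [np.array([4.0])]}
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(float(norm))  # 5.0
```

Scaling all leaves by one shared factor preserves the direction of the full gradient, which is why global-norm clipping is usually preferred over clipping each leaf independently.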

Transform & Compilation

  • StatefulMapping: Introduced for enhanced state management in vmap transformations

  • Mixin Classes: Added Mode, JointMode, Batching, and Training classes for computation behavior control

  • Bounded Cache: Implemented thread-safe bounded cache for JAX Jaxpr with:

    • Comprehensive validation
    • Statistics tracking
    • Enhanced error handling
  • Input Validation: Enhanced input size handling to support numpy integer types

  • Context Parameters: The update method now accepts additional context parameters for configuring environment settings
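A thread-safe bounded cache with statistics can be sketched with an OrderedDict guarded by a lock. This is a hypothetical illustration, not brainstate's _BoundedCache: the least-recently-used eviction policy, the method names, and the statistics keys are all assumptions, and plain strings stand in for the compilation keys a Jaxpr cache would use.

```python
import threading
from collections import OrderedDict


class BoundedCache:
    """Hypothetical LRU cache: at most `maxsize` entries, safe to share across threads."""

    def __init__(self, maxsize: int = 128):
        if not isinstance(maxsize, int) or maxsize <= 0:
            raise ValueError(f"maxsize must be a positive int, got {maxsize!r}")
        self._data = OrderedDict()
        self._maxsize = maxsize
        self._lock = threading.Lock()
        self._hits = self._misses = self._evictions = 0

    def get(self, key, default=None):
        with self._lock:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                self._hits += 1
                return self._data[key]
            self._misses += 1
            return default

    def set(self, key, value):
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            if len(self._data) > self._maxsize:
                self._data.popitem(last=False)  # evict the least recently used entry
                self._evictions += 1

    def get_stats(self):
        with self._lock:
            lookups = self._hits + self._misses
            return {
                "size": len(self._data),
                "maxsize": self._maxsize,
                "hits": self._hits,
                "misses": self._misses,
                "evictions": self._evictions,
                "hit_rate": self._hits / lookups if lookups else 0.0,
            }


cache = BoundedCache(maxsize=2)
cache.set("a", 1)
cache.set("b", 2)
cache.get("a")      # hit; "b" becomes the least recently used entry
cache.set("c", 3)   # exceeds maxsize, so "b" is evicted
cache.get("b")      # miss
stats = cache.get_stats()
```

A single lock keeps the sketch simple; every read also mutates recency order, so even lookups need the lock.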

Random & Initialization

  • Dependencies: Integrated braintools for initialization and surrogate gradient functions

    • Updated all initialization references
    • Refactored to use braintools.surrogate for spike functions
  • Random Functions: Replaced uniform_for_unit with jr.uniform for consistency and performance

Utilities & Infrastructure

  • Filter Utilities: Added comprehensive filter utilities for nested structures

  • Pretty Representation: Enhanced pretty_pytree module with:

    • Comprehensive documentation
    • Mapping functions
    • JAX integration
  • Error Handling: Improved state length validation by replacing assertions with ValueError exceptions

  • Collective Operations: Updated function signatures so that collective operations return the target
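Filters over nested structures are often expressed as predicates on (path, value) pairs that can be combined. The sketch below is hypothetical and only illustrates that pattern; brainstate's own filter utilities may use different names and signatures. Here a type becomes an isinstance check, a string matches a path component, and any_of, all_of, and not_ compose predicates.

```python
from typing import Callable, Tuple

Path = Tuple[object, ...]
Predicate = Callable[[Path, object], bool]


def to_predicate(flt) -> Predicate:
    # Normalize a filter spec into a (path, value) -> bool predicate.
    # Types are checked first because classes are also callable.
    if isinstance(flt, type):
        return lambda path, value: isinstance(value, flt)
    if isinstance(flt, str):
        return lambda path, value: flt in path
    if callable(flt):
        return flt
    raise TypeError(f"unsupported filter: {flt!r}")


def any_of(*filters) -> Predicate:
    preds = [to_predicate(f) for f in filters]
    return lambda path, value: any(p(path, value) for p in preds)


def all_of(*filters) -> Predicate:
    preds = [to_predicate(f) for f in filters]
    return lambda path, value: all(p(path, value) for p in preds)


def not_(flt) -> Predicate:
    pred = to_predicate(flt)
    return lambda path, value: not pred(path, value)


def select(tree, flt, path: Path = ()):
    """Collect {path: leaf} for leaves of nested dicts/lists that match `flt`."""
    pred = to_predicate(flt)
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return {path: tree} if pred(path, tree) else {}
    out = {}
    for key, child in items:
        out.update(select(child, pred, path + (key,)))
    return out


params = {"encoder": {"weight": 1.0, "bias": 0.5, "name": "enc"}, "decoder": [2.0, "dec"]}
result = select(params, all_of(float, not_("bias")))
print(result)  # {('encoder', 'weight'): 1.0, ('decoder', 0): 2.0}
```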

Documentation

  • Comprehensive Docstrings: Added detailed NumPy-style docstrings across all modules

    • Full parameter descriptions with types and default values
    • Multiple practical examples in code blocks
    • Comparison sections highlighting differences from PyTorch
    • Mathematical formulas where applicable
    • References to original papers
    • Best practices and use cases
  • New Documentation Pages:

    • brainstate.environ module documentation
    • brainstate.transform (renamed from compile.rst)
    • Random number generation module
    • Pretty representation module
    • State management tutorial notebook
  • Enhanced Examples: Updated documentation examples to use interactive prompts for clarity

  • Module Descriptions: Enhanced documentation with detailed descriptions, key features, and usage examples

Testing

  • Comprehensive Test Coverage: Added extensive test suites for:

    • _BoundedCache and StatefulFunction
    • brainstate.mixin module
    • brainstate.environ module (context management, precision settings, callbacks)
    • DeprecatedModule and proxy creation functionality
    • Compatible import module
    • Metrics module
    • Node class and helper functions
    • Activation functions with shape and gradient checks
    • Dropout layers
    • Surrogate gradient functions
    • Filter utilities
    • Struct module
    • Pretty representation
  • Test Framework Updates: Refactored tests to use absltest for better JAX compatibility

Refactoring

  • File Reorganization:

    • Renamed metrics.py to _metrics.py
    • Renamed _rate_rnns.py to _rnns.py
    • Renamed _init.py to init.py
    • Reorganized graph module files
    • Cleaned up unused imports and classes
  • Code Quality:

    • Streamlined imports across all modules
    • Enhanced code formatting and whitespace consistency
    • Removed unnecessary inheritance and unused elements
    • Simplified type annotations
    • Improved method signatures for clarity
  • Neuron & Synapse Classes: Refactored to use brainpy module and updated initialization methods

  • Base Classes: Changed base class of EINet and Net from DynamicsGroup to Module for consistency

  • Evaluation Functions: Refactored and updated method names for consistency

Infrastructure

  • Version Bump: Updated version to 0.2.0

  • Development Dependencies: Added braintools to development requirements

  • Issue Templates: Added bug report and feature request templates for improved issue tracking

  • CI/CD: Refactored CI configurations to update pip installation commands

  • Git Ignore: Updated to exclude example figures directory and build artifacts

Bug Fixes

  • Enhanced delay handling for multi-dimensional inputs
  • Fixed gradient function references
  • Improved deprecation handling in tests
  • Fixed precision checks in complex number handling

Full Changelog: v0.1.10...v0.2.0

Version 0.1.10

15 Sep 05:53
877fb86

What's Changed

  • Fix precision checks in _get_complex to use 'in' for list membership by @chaoming0625 in #96
  • ⬆️ Bump actions/setup-python from 5 to 6 by @dependabot[bot] in #98
  • ⬆️ Bump actions/download-artifact from 4 to 5 by @dependabot[bot] in #97
  • Enhance StateWithDelay class with detailed documentation and add dela… by @chaoming0625 in #99
  • Bump version to 0.1.10 by @chaoming0625 in #100
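The precision fix in #96 addresses a common Python pitfall: comparing a scalar to a list with == tests equality with the whole list, not membership. The functions and precision values below are hypothetical and only illustrate the pitfall; they are not the actual body of _get_complex.

```python
def complex_dtype_buggy(precision):
    # Bug: `precision == [64, "64"]` compares against the whole list, so it is
    # always False for a scalar argument and this branch is never taken.
    if precision == [64, "64"]:
        return "complex128"
    return "complex64"


def complex_dtype_fixed(precision):
    # Fix: `in` tests whether `precision` is one of the listed values.
    if precision in [64, "64"]:
        return "complex128"
    return "complex64"


print(complex_dtype_buggy(64), complex_dtype_fixed(64))  # complex64 complex128
```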

Full Changelog: v0.1.9...v0.1.10