Releases: chaobrain/brainstate
Version 0.3.0
This release delivers on-device NaN debugging, a unified compilation cache, simplified JAX compatibility, and major internal cleanup — with a net reduction of ~1,800 lines of code. It raises the minimum requirements to Python 3.11 and JAX 0.6.0.
Breaking Changes
- **Python >= 3.11 required**: Dropped support for Python 3.10. The `requires-python` field and classifiers now start at 3.11.
- **JAX >= 0.6.0 required**: All dependency groups (`cpu`, `cuda12`, `cuda13`, `tpu`, `testing`) now mandate `jax>=0.6.0`.
- **Unified compilation cache in `StatefulFunction`**: The four separate internal caches (`_cached_jaxpr`, `_cached_out_shapes`, `_cached_jaxpr_out_tree`, `_cached_state_trace`) have been consolidated into a single `_compilation_cache` storing `_CachedCompilation` objects. `get_cache_stats()` now returns `{'compilation_cache': {...}}` instead of four individual entries.
- **Immutable `CacheKey` replaces `hashabledict`**: `get_arg_cache_key()` now returns a `CacheKey` (a `NamedTuple`) instead of the mutable `hashabledict`. Code that directly inspected or constructed cache keys must be updated.
- **Removed internal `_make_jaxpr` function**: The custom tracing implementation has been deleted in favor of calling `jax.make_jaxpr()` directly (available in JAX >= 0.6.0).
- **Removed `debug_depth` and `debug_context` from `GradientTransform`**: The `depth` and `context` parameters for NaN debugging no longer exist following the debug module rewrite.
- **Removed `breakpoint_if` function**: The conditional breakpoint helper has been removed from `brainstate.transform._debug`.
- **Removed `extend_axis_env_nd` from compatible imports**: This compatibility shim is no longer exported.
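To illustrate why an immutable `NamedTuple` makes a safer cache key than a mutable dict, here is a minimal sketch; the field names are hypothetical and do not reflect brainstate's actual `CacheKey` layout.

```python
from typing import NamedTuple

# Hypothetical sketch: field names are illustrative, not brainstate's real CacheKey.
class CacheKey(NamedTuple):
    static_args: tuple   # hashable fingerprint of static arguments
    tree_def: str        # flattened pytree structure of dynamic args
    avals: tuple         # (shape, dtype) pairs of dynamic leaves

key = CacheKey(static_args=(3,), tree_def="(*, *)", avals=(((4,), "float32"),))

# NamedTuples are hashable and immutable, so they are safe dict keys;
# a mutable dict-based key could be altered after insertion into the cache.
cache = {key: "compiled-artifact"}
assert cache[CacheKey((3,), "(*, *)", (((4,), "float32"),))] == "compiled-artifact"
```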
New Features
On-Device NaN/Inf Detection
- Complete rewrite of the NaN debugging system (
brainstate.transform._debug). NaN checking now runs on-device via JAX primitives rather than pulling data to the host, providing significantly better performance. - Uses
jax.debug.callbackwith thread-local storage to collect and report NaN findings. - Error tracebacks now point to the user's source code via
source_info_util.user_context, producing IDE-clickable source locations extracted from jaxpr equations. - Recursive instrumentation of nested primitives (
jit,cond,while,scan) for comprehensive NaN detection throughout the computation graph. - More compact and informative error messages via
_format_nan_message().
JAX Traceback Filtering
- Registered brainstate with JAX's `traceback_util.register_exclusion()` so internal frames are hidden in user-facing error tracebacks. This follows the same pattern as Flax, Equinox, and other JAX ecosystem libraries.
- Users can still see full tracebacks by setting `JAX_TRACEBACK_FILTERING=off`.
State Validation at Call Time
- New `_validate_state_shapes()` method checks that current state shapes and dtypes match those recorded at compile time. `StatefulFunction.__call__()` automatically validates before execution, catching state shape mismatches early with clear error messages.
- Added `static_argnums` bounds validation: `make_jaxpr()` now raises `ValueError` if indices exceed the number of positional arguments.
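A simplified sketch of the compile-time vs. call-time check described above; the method name mirrors the release notes, but the logic here is a stand-in, not brainstate's implementation.

```python
import numpy as np

def validate_state_shapes(recorded, current):
    """recorded/current: dict of name -> array. Raise on shape or dtype drift."""
    for name, ref in recorded.items():
        arr = current[name]
        if arr.shape != ref.shape or arr.dtype != ref.dtype:
            raise ValueError(
                f"State '{name}' changed since compilation: "
                f"expected {ref.shape}/{ref.dtype}, got {arr.shape}/{arr.dtype}"
            )

recorded = {"w": np.zeros((2, 3), dtype=np.float32)}
validate_state_shapes(recorded, {"w": np.ones((2, 3), dtype=np.float32)})  # ok
try:
    validate_state_shapes(recorded, {"w": np.ones((3, 2), dtype=np.float32)})
except ValueError as e:
    print(e)  # clear message naming the offending state
```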
New Compatible Import
- Added the `mapped_aval` compatible import with version-based routing: `jax.core.mapped_aval` for JAX < 0.8.2, `jax.extend.core.mapped_aval` for JAX >= 0.8.2.
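The version-based routing can be sketched as a pure-Python helper (no JAX import needed for the illustration); the function names here are hypothetical, only the two module paths come from the notes above.

```python
def parse_version(v: str) -> tuple:
    """'0.8.2' -> (0, 8, 2), for tuple-wise comparison."""
    return tuple(int(p) for p in v.split(".")[:3])

def resolve_mapped_aval_path(jax_version: str) -> str:
    # Per the release notes: jax.core before 0.8.2, jax.extend.core from 0.8.2 on.
    if parse_version(jax_version) >= (0, 8, 2):
        return "jax.extend.core.mapped_aval"
    return "jax.core.mapped_aval"

assert resolve_mapped_aval_path("0.6.0") == "jax.core.mapped_aval"
assert resolve_mapped_aval_path("0.8.2") == "jax.extend.core.mapped_aval"
```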
Improvements
- **Atomic cache writes**: Compilation results are only stored on success, eliminating partial cache entries on error. Uses a double-checked locking pattern for thread safety during compilation.
- **Better cache key hashing**: Dynamic args/kwargs are now flattened via `jax.tree.flatten()` before hashing, fixing non-deterministic hashing issues with custom pytree nodes (e.g., `Quantity`).
- **Modern Python type annotations**: Migrated from `typing.Tuple`, `typing.List`, `typing.Dict`, `typing.Optional`, and `typing.Union` to the built-in `tuple`, `list`, `dict`, `X | None`, and `X | Y` syntax across the codebase.
- **IR visualization compatibility**: Replaced direct `jax.core.X` references with compatible imports (`Var`, `ClosedJaxpr`, `Jaxpr`, `JaxprEqn`, `Literal`, `DropVar`) in the IR visualizer.
- **Deterministic error reporting**: `jax.debug.callback` in `_error_if.py` now uses `ordered=True` for deterministic error callback ordering.
- **Graph operations cleanup**: Major refactoring of `_operation.py`, `_node.py`, `_convert.py`, and `_context.py` with streamlined docstrings, better thread-safety documentation, and cleaner context managers.
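The double-checked locking behind the atomic cache writes can be sketched as follows; this is a generic illustration of the pattern, not brainstate's actual cache class.

```python
import threading

class CompilationCache:
    """Entry is written only after compilation succeeds: failures leave no partial entry."""

    def __init__(self):
        self._cache = {}
        self._lock = threading.Lock()

    def get_or_compile(self, key, compile_fn):
        entry = self._cache.get(key)           # first check, lock-free fast path
        if entry is not None:
            return entry
        with self._lock:
            entry = self._cache.get(key)       # second check, under the lock
            if entry is None:
                result = compile_fn()          # may raise: nothing is stored on error
                self._cache[key] = entry = result  # single atomic store on success
        return entry

cache = CompilationCache()
calls = []
value = cache.get_or_compile("k", lambda: calls.append(1) or "artifact")
assert value == "artifact"
assert cache.get_or_compile("k", lambda: "other") == "artifact"  # cached, not recompiled
assert len(calls) == 1
```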
Bug Fixes
- Fixed `Delay.__init__` initialization order: `update_every` is now initialized before `register_entry` is called, preventing attribute errors during entry registration (#135).
- Fixed `graph_to_tree` private attribute access: Replaced internal `_mapping` access with public API usage in `_convert.py`.
Internal Changes
- Massive docstring reduction across the graph module (~1,000+ lines removed), replacing verbose multi-paragraph docstrings with concise descriptions.
- Cleaned up TypeVar usage: removed the unused `C` and `Names` aliases, renamed the `Node` TypeVar to `N`, and removed the `Hashable` bound from type variables.
- Removed unused tests (`test_all_exports`, `test_function_imports_availability`) from the compatible import tests.
- Rewrote the debug and make_jaxpr test suites to match the new APIs.
- IR optimization imports are now lazy-loaded inside `make_jaxpr()`, only when `ir_optimizations` is configured.
CI/CD
- Bumped `actions/upload-artifact` from v6 to v7.
- Bumped `actions/download-artifact` from v7 to v8.
What's Changed
- fix(nn): initialize update_every before register_entry by @Routhleck in #135
- deps(deps): bump actions/upload-artifact from 6 to 7 by @dependabot[bot] in #133
- deps(deps): bump actions/download-artifact from 7 to 8 by @dependabot[bot] in #132
- Simplify JAX compat: use jax.make_jaxpr and aval helpers by @chaoming0625 in #137
- Refactor graph ops, update JAX/Python requirements, improve tests by @chaoming0625 in #138
- Add on-device NaN debugging and unify StatefulFunction cache by @chaoming0625 in #139
Full Changelog: v0.2.10...v0.3.0
Version 0.2.10
This release introduces a comprehensive NaN debugging system for gradient computations, refactors the module mapping API for improved clarity, and adds graph context utilities for advanced state management.
New Features
NaN Debugging System
- **JIT-Compatible NaN/Inf Debugging**: New debugging utilities for identifying NaN and Inf values during gradient computations:
  - `debug_nan`: Analyze a function for NaN/Inf values with detailed reporting
  - `debug_nan_if`: Conditional NaN debugging with predicate-based activation
  - Full JIT compatibility for seamless integration into compiled workflows
  - Support for debugging NaN in `while` and `scan` primitives
  - Detailed analysis output including variable names, shapes, and affected indices
- **Gradient Function Integration**: Added a `debug_nan` parameter to gradient transformation functions:
  - `grad`: Enable NaN debugging during gradient computation
  - `vector_grad`: NaN debugging for vectorized gradients
  - `jacobian` and `jacobian_reverse`: NaN debugging for Jacobian computations
  - `hessian`: NaN debugging for Hessian computations
- **Breakpoint Utility**: New `breakpoint` function for conditional debugging:
  - Wraps `jax.debug.breakpoint` with predicate support
  - Only triggers when the specified condition is True
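The predicate-gating idea can be shown with a plain-Python stand-in (the real utility wraps `jax.debug.breakpoint` inside traced code; the callback here is only a placeholder for illustration).

```python
def breakpoint_if(pred, on_trigger):
    """Invoke on_trigger only when pred is True (stand-in for a real debugger hook)."""
    if pred:
        on_trigger()

hits = []
breakpoint_if(False, lambda: hits.append("stop"))  # predicate False: no trigger
breakpoint_if(True, lambda: hits.append("stop"))   # predicate True: triggers
assert hits == ["stop"]
```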
API Changes
Module System
- **Renamed `ModuleMapper` to `Map`**: Simplified naming for the vectorized module wrapper:
  - `Map` provides vectorized (`vmap2`) and parallel (`pmap2`) mapping over modules
  - `ModuleMapper` is retained as a deprecated alias for backward compatibility
  - Internal `_ModuleMapperCalling` renamed to `_MapCaller` for consistency
- **Enhanced `Map.map()` Method**: Now accepts callable functions for flexible mapping operations
Bug Fixes
- Fixed the `get_backend` import for compatibility across different JAX releases
- Removed `abstractmethod` decorators from the `Regularization` class to allow proper instantiation
- Cleaned up unused imports in module initialization files
Internal Changes
- Added a comprehensive test suite for NaN debugging (`_debug_test.py`, 938 lines)
- Removed the deprecated `_mapping3.py` module and associated tests
- Streamlined module exports in `__init__.py` files
Version 0.2.9
This release introduces a powerful state hook system for advanced state management, refactors neural network modules with enhanced parameter handling, and improves delay mechanisms with frequency-controlled updates.
State Management
State Hook System
- **Global Hook Infrastructure**: Comprehensive hook system for intercepting state operations:
  - `register_read_hook`: Register hooks that execute when state values are read
  - `register_write_hook`: Register hooks that execute when state values are written
  - `register_restore_hook`: Register hooks that execute when state values are restored
  - `HookManager`: Thread-safe manager for organizing and executing hooks with priority support
  - `HookContext`: Context manager for scoped hook registration and execution
  - Enables advanced use cases: logging, debugging, value transformation, validation
- **Enhanced `State` Class**: Improved state management with hook integration:
  - Automatic hook execution on read/write operations
  - Better cache key handling for improved performance
  - Enhanced thread safety and context management
  - Comprehensive test coverage (346 tests for thread safety, 320 tests for hooks)
Neural Network Components
Parameter Management (brainstate.nn.Param and brainstate.nn.Const)
- **Renamed Classes**: Simplified naming convention:
  - `ParaM` → `Param`: Trainable parameter wrapper
  - `ConstM` → `Const`: Non-trainable constant wrapper
- **Enhanced Caching System**: Improved parameter precomputation and caching:
  - `param_precompute` context manager for efficient parameter-transformation caching
  - `cache()` method for retrieving cached parameter values
  - Support for custom precompute functions
  - Automatic cache invalidation and management
  - 391 comprehensive tests for caching behavior
- **Hierarchical Parameter Data (`brainstate.nn.HiData`)**: New module for structured parameter organization:
  - `define_param_data()` method for declaring hierarchical parameter structures
  - Support for nested parameter groups
  - Improved parameter surgery and manipulation
  - Enhanced type hints and documentation
Module System Enhancements
- **`ModuleMapper`**: New helper for vectorized module operations (formerly `Vmap2Module`):
  - Simplified API for applying `vmap2` to module methods
  - Automatic state management for vectorized operations
  - Consistent interface with `Vmap2ModuleCaller`
  - Comprehensive documentation with usage examples
- **Enhanced Module Methods**:
  - `parameters()`: Iterate over all parameters in the module hierarchy
  - `named_parameters()`: Iterate over parameters with their qualified names
  - `children()`: Access direct child modules
  - `named_children()`: Access child modules with names
  - `init_all_states()`: Initialize states with additional keyword arguments
  - Improved `Sequential` with `extend()` and `insert()` methods
Delay Mechanisms
- **Frequency-Controlled Updates**: Enhanced `Delay` class with flexible update strategies:
  - `update_every` parameter: Control how often delay buffers are updated
  - Support for integer steps (update every N steps)
  - Support for time-based updates with physical units (e.g., `1*ms`)
  - Automatic handling of unit conversions and validation
  - Comprehensive tests covering various update strategies
- **Unified Delay Implementation**: Refactored delay mechanism:
  - Ring buffer implementation for efficient historical value storage
  - Support for linear interpolation
  - Better handling of multi-dimensional inputs
  - Improved integration with neural network modules
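The ring-buffer idea behind the delay refactor can be sketched as follows; this is illustrative only, while the real `Delay` class additionally supports physical units, interpolation, and `update_every`.

```python
import numpy as np

class RingDelay:
    """Fixed-size circular buffer of past values (illustrative sketch)."""

    def __init__(self, max_steps, size):
        self.buf = np.zeros((max_steps, size))
        self.i = 0  # index of the slot holding the most recent value

    def push(self, value):
        self.i = (self.i + 1) % len(self.buf)  # advance and overwrite oldest slot
        self.buf[self.i] = value

    def retrieve(self, steps_back):
        return self.buf[(self.i - steps_back) % len(self.buf)].copy()

d = RingDelay(max_steps=4, size=1)
for t in [1.0, 2.0, 3.0]:
    d.push(t)
assert d.retrieve(0)[0] == 3.0  # current value
assert d.retrieve(2)[0] == 1.0  # value from two steps ago
```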
Regularization
- **Comprehensive Regularization Module** (`brainstate.nn._regularization`, 2,840 lines):
  - Complete suite of regularization techniques
  - L1, L2, and elastic net regularization
  - Dropout variants
  - Weight decay and other parameter constraints
  - 1,261 tests for regularization functionality
- **Transform Module** (`brainstate.nn._transform`, 1,661 lines):
  - Advanced parameter transformations
  - Quantization support
  - Normalization techniques
  - Integration with the caching system
  - 452 comprehensive tests
Transformations
Vectorization and Parallelization
- **Mapping Function Refactoring**: Reorganized mapping implementations:
  - Renamed `_mapping.py` → `_mapping2.py` (primary `vmap2` implementation)
  - Renamed `_mapping_old.py` → `_mapping1.py` (legacy `vmap` implementation)
  - Added `_mapping3.py`: New `pmap2` implementation for parallelization
  - `vmap2_new_states`: Helper for creating new states in vectorized operations
  - Relaxed return type requirements for more flexible mapping functions
- **Enhanced Documentation**: Updated tutorials and API documentation:
  - Comprehensive `vmap2` tutorial with practical examples
  - Enhanced parallelization documentation for `pmap2`
  - Updated state management guides
  - Expanded gradient transformation documentation
Compatibility and Utilities
JAX Compatibility
- **Enhanced JAX Integration**: Improved compatibility with newer JAX versions:
  - Updated backend import for JAX version detection
  - Enhanced `get_aval` function for JAX version compatibility
  - Standardized `jit_named_scope` arguments
  - Support for JAX 0.8.0+ in CI configuration
Utility Functions
- **Dataclass Support**: Added the `is_dataclass` utility function in `brainstate.util.struct`:
  - Robust dataclass type checking
  - Better handling of dataclass-based structures
- **Tracer Utilities**: New `_tracers.py` module for JAX tracer handling:
  - `current_jax_trace()`: Get the current JAX trace context with version compatibility
  - Helper functions for working with JAX abstract values
Graph Operations
- **Context Management** (`brainstate.graph._context`):
  - New context management system for graph operations (119 lines)
  - `TraceContextError`: Specialized error class for tracing issues
  - Enhanced state tracking during graph construction
  - 64 tests for context management
- **Conversion Utilities** (`brainstate.graph._convert`):
  - New conversion utilities for graph operations (278 lines)
  - Better handling of graph transformations
  - Improved node conversion logic
Random Number Generation
- **Enhanced `RandomState`**: Improved random number generation:
  - Better compatibility with newer JAX versions (98 lines of improvements)
  - Enhanced state management for random keys
  - Improved thread safety
  - Better error messages and validation
Documentation
- **Comprehensive API Documentation**: Expanded documentation across all modules:
  - `brainstate.rst`: Reorganized with improved structure (21 lines removed, refactored into submodules)
  - `environ.rst`: Added 48 lines of documentation for environment state and keys
  - `nn.rst`: Added 222 lines documenting neural network components
  - `transform.rst`: Added 132 lines for gradient transformations and mapping functions
- **Tutorial Updates**:
  - Updated the vectorization tutorial to reflect the `vmap` → `vmap2` transition
  - Enhanced examples with `ModuleMapper` usage
  - Improved state management examples
Breaking Changes
- **Renamed Functions and Classes**:
  - `ParaM` → `Param`
  - `ConstM` → `Const`
  - `vmap` → `vmap2` (the old `vmap` is preserved in `_mapping1.py` for compatibility)
  - `pmap` → `pmap2`
  - `_param_data` → `_hidata`
- **Parameter Naming Standardization**:
  - `fit_par` → `fit` across all modules
  - `brainscale` → `braintrace` in example files
- **Method Signature Changes**:
  - `init_all_states()` now accepts additional keyword arguments
  - `param_precompute()` signature updated to support caching and custom functions
  - Module initialization methods enhanced with keyword argument support
Testing
- Comprehensive Test Coverage: Added 4,000+ lines of new tests
- Thread safety tests: 346 tests ensuring thread-safe operations
- Hook system tests: 320 tests for state hooks
- State management tests: 924 tests expanded coverage
- Parameter caching tests: 391 tests for caching behavior
- Delay mechanism tests: 244 tests for delay functionality
- HiData tests: 463 tests for hierarchical data structures
- Module tests: 661 tests expanded coverage
- Regularization tests: 1,261 tests
- Transform tests: 452 tests
- Mapping tests: Updated for `vmap2` and `pmap2`
Bug Fixes
- Fixed cache key handling in state management
- Improved error messages for missing states in gradient transformations
- Enhanced validation for delay update frequency
- Corrected import paths for better module organization
- Fixed compatibility issues with JAX 0.8.0+
Internal Changes
- Reorganized import statements across all modules for clarity
- Enhanced type hints throughout the codebase
- Improved code documentation with comprehensive docstrings
- Streamlined module exports in `__all__` definitions
- Better separation of concerns in module organization
What's Changed
- Enhance random utils and dataclass helpers for newer JAX by @chaoming0625 in #126
- Add State hook system and refactor nn modules and transforms by @chaoming0625 in #127
- Update vectorization docs for vmap2 and relax mapping return type by @chaoming0625 in #128
- Refactor Param and delay APIs and add ModuleMapper/pmap2 helpers by @chaoming0625 in #129
- Enhance Delay with frequency-controlled updates and unit-aware timing by @chaoming0625 in #130
Full Changelog: v0.2.8...v0.2.9
Version 0.2.8
This release ensures compatibility with JAX 0.8.2+ and removes the experimental module that was superseded by upstream changes.
Compatibility
- **JAX 0.8.2+ Support**: Added compatibility with JAX version 0.8.2 and later. The library now uses `jax.make_jaxpr` directly for JAX >= 0.8.2 while maintaining backward compatibility with earlier versions.
Breaking Changes
- **Removed the `abstracted_axes` parameter**: The `abstracted_axes` parameter has been removed from:
  - `StatefulFunction.__init__`
  - `StatefulMapping.__init__`
  - the `make_jaxpr` function
  - the `_make_jaxpr` internal function
Improvements
- **Debug mode support**: Added a `debug_call` method to `StatefulFunction` for proper execution when `jax.config.jax_disable_jit` is enabled. This improves debugging workflows by allowing stateful functions to execute without JIT compilation.
- **Lazy loading optimization**: The `RandomState` import in the `_mapping` module is now lazily loaded via `_import_rand_state()`, improving initial import performance and reducing circular dependency issues.
Internal Changes
- Removed unused imports (`annotate`, `api_boundary` from `jax._src`) at module level; they are now imported only where needed
- Removed the internal helper functions `_broadcast_prefix` and `_flat_axes_specs`
- Simplified the `_abstractify` function by removing abstracted-axes handling
- Updated example files to reflect API changes
What's Changed
- fix: compatible with `jax>=0.8.2` by @chaoming0625 in #124
- chore(changelog): update release notes for version 0.2.8 by @chaoming0625 in #125
Full Changelog: v0.2.7...v0.2.8
Version 0.2.7
BrainState 0.2.7 modernizes the experimental compilation stack, deepens the transformation APIs, and tightens runtime infrastructure across the project.
Experimental Compiler and Visualization
- Introduced the experimental `neuroir` compiler built on dataclass-based graph IR elements and an explicit `CompilationContext`, improving dependency tracking, hidden-state mapping, and ClosedJaxpr fidelity even for self-connections and delay buffers.
- Added `GraphDisplayer` and `TextDisplayer` backends with hierarchical and force-directed layouts, plus richer diagnostics and tests that cover large sample networks and neuro-graph visualizations.
Transformations and Autodiff
- Added the `jit_named_scope` decorator and supporting utilities so nested transformations emit meaningful names inside traced functions, together with `_make_jaxpr` refinements that separate dynamic/static arguments and improve caching semantics for `StatefulFunction`.
- Expanded the gradient toolkit by exporting the new Jacobian (forward and reverse), Hessian, and SOFO transforms, unifying gradient handling for classes, auxiliary returns, and state-aware updates through the transform module.
State and Runtime Enhancements
- Replaced the experimental `ArrayParam` with a dedicated `DelayState`, propagating the new state through the compiler, delay modules, and neuro-IR so historical buffers participate in tracing and optimization just like other states.
- Environment helpers can now run against injected `EnvironmentState` instances, enabling sandboxed or per-thread configurations, while DelayState-aware unit tests extend coverage of the updated modules.
Experimental and Infrastructure Updates
- Completed the neuron IR → neuroir rename, aligned the GDiist BPU codebase with the new terminology, and added new sample networks plus placeholder skips to keep the growing compiler/displayer test surface manageable.
- Added `braincell` to the development requirements, refreshed documentation wording, and kept CI dependencies current for the GitHub Actions runners.
Bug Fixes
- Hardened caching, randomness, and initialization logic by fixing `get_arg_cache_key`, removing stale decorator parameters, validating truncated normal draws, and correcting the exported version metadata.
- Declared Python 3.14 support and cleaned up compiler import ordering to keep linting noise low.
What's Changed
- Add Jaxpr optimization passes and integrate constant folding by @chaoming0625 in #114
- Enhance compilation with BPU parser, IR optimizations, and device wrappers by @chaoming0625 in #115
- deps(deps): bump actions/checkout from 5 to 6 by @dependabot[bot] in #116
- update document by @xinzhu-L in #117
- Support custom EnvironmentState instances in environ helpers by @chaoming0625 in #118
- Add unified gradient transforms with Jacobian, Hessian, SOFO by @chaoming0625 in #119
- Refactor(state): remove ArrayParam, add DelayState by @chaoming0625 in #122
- feat(neuroir): add NeuroIR core module and API by @chaoming0625 in #123
Full Changelog: v0.2.4...v0.2.7
Version 0.2.4
This release introduces the new ArrayParam state type for parameter arrays with custom transformations, experimental BPU backend export support, enhanced JAXPR optimization capabilities, and improved module organization.
New Features
ArrayParam State Type
- **`ArrayParam` Class**: New state type for managing parameter arrays with advanced transformation control:
  - Supports custom transformations (e.g., quantization, normalization) that preserve array identity
  - Enables `vmap`, `pmap`, and other JAX transformations to correctly handle stateful parameters
  - Provides an `identity()` method that returns the raw array without applying custom transformations
  - Integrates seamlessly with the existing State management infrastructure
  - Useful for implementing quantization-aware training and other advanced parameter manipulations
  - Comprehensive documentation with usage examples and best practices
Experimental BPU Backend Export (brainstate.experimental.gdiist_bpu)
- **BPU Backend Export Support**: Complete infrastructure for exporting models to the GDiist BPU hardware backend (727 lines):
  - `export.py`: Main export API with a `to_bpu()` function for model conversion
  - `parser.py`: Operation parser that analyzes JAXPR to identify operations and connections (305 lines)
  - `data.py`: Data structures and analysis utilities for operation representation (215 lines)
- **Operation Parser Features**:
  - Automatic detection of operations from JAXPR equations using brainevent primitives
  - Data-flow analysis to identify connections between operations
  - Support for various operation types: slice, add, multiply, and more
  - Detailed analysis output showing equations, inputs, outputs, and connections
- **Analysis and Debugging Tools**:
  - `display_analysis_results()`: Comprehensive visualization of parsed operations
  - Shows operation details including equation count, variable mappings, and connections
  - Displays connection information with producer/consumer operations and variable details
  - Example implementation in `examples/400_CUBA_2005_bpu.py`
Enhancements
JAXPR Optimization Improvements
- **Enhanced Constant Folding**:
  - Better handling of literal values in constant-folding optimization
  - Improved detection and elimination of redundant literal operations
  - More efficient constant propagation through computation graphs
- **Identity Equation Optimization**:
  - Optimized handling of `Literal` outputs to avoid unnecessary bridging equations
  - Improved identity-equation creation for interface preservation
  - Better handling of edge cases in optimization passes
- **Error Handling**:
  - Added a fallback source-info utility for better error messages
  - Fixed potential `NoneType` errors in equation handling
  - Improved validation of optimization results
State Management
- **Enhanced State Tests**: Comprehensive test refactoring with improved coverage (454 tests):
  - Better organization of state-type tests
  - More thorough validation of state behavior
  - Enhanced test readability and maintainability
What's Changed
- deps(deps): bump actions/download-artifact from 5 to 6 by @dependabot[bot] in #109
- deps(deps): bump actions/upload-artifact from 4 to 5 by @dependabot[bot] in #110
- Add ArrayParam and integrate JAXPR optimizations by @chaoming0625 in #112
- Add experimental BPU backend export support by @chaoming0625 in #111
- Standardize module attribution for random and transform by @chaoming0625 in #113
Full Changelog: v0.2.3...v0.2.4
Version 0.2.3
This release introduces powerful IR (Intermediate Representation) optimization capabilities for JAX computation graphs, comprehensive state management refactoring for vectorized mapping operations, and extensive testing infrastructure improvements.
New Features
IR Optimization (brainstate.transform._ir_optim)
- **Intermediate Representation Optimization Module** (876 lines): Complete suite of compiler-level optimizations for JAX computation graphs:
  - `constant_fold`: Evaluates constant expressions at compile time, reducing runtime computation
  - `dead_code_elimination`: Removes equations whose outputs are unused, reducing computation overhead
  - `common_subexpression_elimination`: Identifies and reuses results of identical computations
  - `copy_propagation`: Eliminates unnecessary copy operations by propagating original variables
  - `algebraic_simplification`: Applies algebraic identities (x+0=x, x*1=x, x-x=0, etc.)
  - `optimize_jaxpr`: Orchestrates multiple optimization passes with configurable iteration and verbose mode
- **`IdentitySet` Class**: Custom set implementation using object identity (`id()`) instead of equality:
  - Enables proper handling of JAX variables and Literals in optimization passes
  - Implements the `MutableSet` interface with full collection-protocol support
  - Essential for tracking variable usage without relying on equality comparisons
Optimization Features
- **Interface Preservation**: All optimizations preserve function input/output variables (invars/outvars):
  - Identity equations are automatically added when needed to maintain correct interfaces
  - Uses the `convert_element_type` primitive with matching dtypes as the identity operation
  - Ensures optimized functions remain drop-in replacements
- **Optimization Pipeline**: Configurable multi-pass optimization with convergence detection:
  - Customizable optimization sequence via the `optimizations` parameter
  - Automatic convergence detection when no more reductions are possible
  - Maximum iteration control with the `max_iterations` parameter
  - Verbose mode with detailed statistics and progress tracking
- **JAX Integration**: Full support for JAX primitives and special cases:
  - Blacklist for primitives that shouldn't be folded (`broadcast_in_dim`, `broadcast`)
  - Proper handling of the `closed_call` and `scan` primitives
  - Support for both `Jaxpr` and `ClosedJaxpr` inputs
State Management Refactoring (brainstate.transform._mapping)
- **Renamed `vmap` to `vmap2`**: Major refactoring of the vectorized mapping implementation (647 lines):
  - Enhanced state management with improved axis tracking
  - Better error messages and validation
  - Streamlined state-value restoration logic
- **Old `vmap` Implementation Preserved** (`_mapping_old.py`, 579 lines): Legacy `vmap` with explicit state management:
  - Exports the original `vmap` and `vmap_new_states` functions
  - Maintains backward compatibility for existing code
  - Specialized for stateful functions with explicit state parameters
Documentation
API Documentation
- **`transform.rst`**: Added a comprehensive IR Optimization section (24 lines):
  - Detailed module description explaining the compiler optimizations
  - All six optimization functions documented with autosummary
  - Clear explanation of the benefits: reduced computation overhead, improved runtime performance
  - Positioned between the Compilation Tools and Gradient Computations sections
- **NumPy-style Docstrings**: All optimization functions include:
  - Comprehensive parameter descriptions with types and defaults
  - Detailed return-value documentation
  - Notes sections explaining the preservation of function interfaces
  - Multiple practical examples demonstrating usage
  - Algorithm descriptions for complex optimizations
  - Cross-references between related functions
Enhancements
Optimization Pipeline
- **Progress Tracking**: Verbose mode shows equation-count changes after each optimization:
  - Displays initial, intermediate, and final equation counts
  - Shows reduction statistics with percentages
  - Indicates convergence detection
  - Reports iteration counts
- **Validation**: Runtime checks ensure optimization correctness:
  - Verifies input variables are unchanged after optimization
  - Validates that output variables are preserved
  - Raises clear errors if the interface is violated
  - Checks for valid optimization names
- **Flexibility**: Customizable optimization sequences:
  - Apply all optimizations in the recommended order (default)
  - Select specific optimizations only
  - Control iteration limits
  - Toggle verbose output
JAX Integration
- **`JaxprEqn` Construction**: Proper handling of the required `ctx` parameter:
  - Uses `JaxprEqnContext(None, True)` for identity equations
  - Ensures compatibility with the JAX internal API
  - Maintains proper equation structure
- **Primitive Handling**: Special cases for JAX primitives:
  - Blacklist for primitives that shouldn't be optimized
  - Proper parameter extraction and validation
  - Support for the `effects` and `source_info` fields
Bug Fixes
- Fixed `JaxprEqn` constructor calls to include the required `ctx` parameter (7th positional argument)
- Corrected import paths for `vmap2` in test files and tutorials
- Fixed `RandomState.uniform()` calls to use the `size` parameter instead of `shape`
- Enhanced test assertions for proper state-axis handling
- Improved error messages for batch-axis mismatches
Refactoring
Transform Module
- **Renamed Files**:
  - `vmap` → `vmap2` in `_mapping.py`
  - The original `vmap` is preserved in `_mapping_old.py` for compatibility
- **Module Exports**: Updated `__init__.py` to export both old and new `vmap` implementations:
  - `vmap` from `_mapping_old.py` (legacy)
  - `vmap2` from `_mapping.py` (new)
  - `vmap_new_states` from both modules
What's Changed
- Introduce JAXPR optimizations and enhance stateful mapping by @chaoming0625 in #108
Full Changelog: v0.2.2...v0.2.3
Version 0.2.2
This release focuses on enhancing hidden state management for recurrent neural networks and eligibility trace-based learning, along with comprehensive testing and documentation improvements.
New Features
Hidden State Classes
- **`HiddenGroupState`**: New class for managing multiple hidden states within a single array:
  - Stores multiple states in the last dimension of a single array
  - Provides `get_value()` and `set_value()` methods for accessing individual states by index or name
  - Optimized for LSTM-style architectures with multiple hidden components (h, c)
  - Includes a `name2index` mapping for convenient state access
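The stacking idea can be illustrated with a small numpy sketch; the class below is a stand-in whose method names loosely mirror the notes, not brainstate's `HiddenGroupState` itself.

```python
import numpy as np

class GroupedHidden:
    """Several named hidden states stacked in the last dimension of one array."""

    def __init__(self, names, shape):
        self.name2index = {n: i for i, n in enumerate(names)}
        self.value = np.zeros(shape + (len(names),))  # states live in the last axis

    def _index(self, key):
        return self.name2index[key] if isinstance(key, str) else key

    def get_value(self, key):
        return self.value[..., self._index(key)]

    def set_value(self, key, v):
        self.value[..., self._index(key)] = v

g = GroupedHidden(["h", "c"], shape=(3,))   # LSTM-style h/c pair, 3 units each
g.set_value("c", 2.0)
assert g.get_value("c").tolist() == [2.0, 2.0, 2.0]
assert g.get_value(0).tolist() == [0.0, 0.0, 0.0]  # 'h' (index 0) untouched
```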
- **`HiddenTreeState`**: New class for managing multiple hidden states with different physical units:
  - Supports a PyTree structure (dict or sequence) of hidden states
  - Preserves physical units (e.g., voltage, current, conductance) via `brainunit` integration
  - Provides `name2unit` and `index2unit` mappings for unit tracking
  - Ideal for neuroscience models with heterogeneous state variables
  - Maintains compatibility with BrainScale online learning
State Utilities
- **`maybe_state`**: New utility function for flexible value extraction:
  - Extracts values from `State` objects automatically
  - Returns non-`State` values unchanged
  - Simplifies writing functions that accept both states and raw values
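A minimal sketch of the `maybe_state` idea, using a tiny `State` stand-in (brainstate's real `State` class is far richer than this):

```python
class State:
    """Toy stand-in for a state wrapper holding a .value attribute."""
    def __init__(self, value):
        self.value = value

def maybe_state(x):
    """Return x.value when x is a State, otherwise return x unchanged."""
    return x.value if isinstance(x, State) else x

def scale(x, k=2):
    # Accepts either a State or a raw value transparently.
    return maybe_state(x) * k

assert scale(State(3)) == 6
assert scale(3) == 6
```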
Enhancements
State Classes
- **`HiddenState`**: Enhanced documentation and type checking:
  - Restricted to `numpy.ndarray`, `jax.Array`, and `brainunit.Quantity` types only
  - Added comprehensive docstrings with examples
  - Clarified equivalence to `brainscale.ETraceState` for online learning
  - Improved error messages for invalid input types
- **`BatchState`**: Now properly exported in the public API:
  - Available via `brainstate.BatchState`
  - Enhanced documentation for batch data management
Documentation
- **API Reference**: Completely reorganized the `brainstate.rst` documentation:
  - Organized into major sections: Core State Classes, State Management, State Utilities, Error Handling, and Submodules
  - Added detailed descriptions for each section and subsection
  - Included comprehensive bullet-point summaries for all APIs
  - Enhanced deprecation warnings with clear migration paths
  - Added module-level descriptions for all submodules
-
State Classes: Enhanced documentation for all state types
- Added detailed use case descriptions
- Included practical examples for each state type
- Clarified semantic distinctions between state types
- Documented integration with JAX transformations
-
JAX Transformations: Improved documentation for stateful transforms
- Enhanced docstrings for
jit,grad,vmap,scan, and other transforms - Added examples showing state management patterns
- Documented state tracing behavior
- Clarified interaction with
StateTraceStack
- Enhanced docstrings for
Transform System
- Enhanced State Finding: New `_find_state.py` module for automatic state discovery
  - Improved state detection in nested structures
  - Better handling of state dependencies
  - Enhanced error messages for state-related issues
- StatefulFunction: Major enhancements to `make_jaxpr` functionality
  - Improved Jaxpr generation for stateful computations
  - Better handling of state read/write tracking
  - Enhanced debugging support
- Mapping Transformations: Significant refactoring of `vmap` and `pmap`
  - Improved state management across vectorized operations
  - Better handling of state broadcasting
  - Enhanced error reporting for mapping operations
Random Number Generation
- Module Reorganization: Complete refactoring of random module structure
  - Renamed `_rand_funs.py` to `_fun.py`
  - Renamed `_rand_seed.py` to `_seed.py`
  - Renamed `_rand_state.py` to `_state.py`
  - Extracted distribution implementations to new `_impl.py` module (691 lines)
- Improved Random State: Enhanced `RandomState` class with better state management
  - Simplified implementation (reduced from 534 to ~300 lines)
  - Better integration with JAX's random number generation
  - Improved thread safety and state isolation
Testing
- Comprehensive Test Suite: Added 102 tests covering all state functionality
  - TestBasicState (13 tests): Core State class operations
  - TestShortTermState (2 tests): Short-term state behavior
  - TestLongTermState (2 tests): Long-term state behavior
  - TestParamState (2 tests): Parameter state usage patterns
  - TestBatchState (2 tests): Batch state functionality
  - TestHiddenState (7 tests): Hidden state with different array types
  - TestHiddenGroupState (9 tests): Multiple hidden state management
  - TestHiddenTreeState (12 tests): PyTree hidden states with units
  - TestFakeState (4 tests): Lightweight state alternative
  - TestStateDictManager (6 tests): State collection management
  - TestStateTraceStack (11 tests): State tracing and recovery
  - TestTreefyState (6 tests): PyTree state references
  - TestContextManagers (6 tests): State context managers
  - TestStateCatcher (8 tests): State catching utilities
  - TestIntegrationScenarios (5 tests): Real-world use cases
Bug Fixes
- Fixed `HiddenGroupState.set_value()` to work correctly with JAX arrays
- Improved error handling in hidden state value validation
- Enhanced type checking for hidden state initialization
Documentation
Tutorial Reorganization
- Basics Tutorials: Complete rewrite and expansion
  - `01_getting_started.ipynb`: Enhanced introduction with practical examples
  - `02_state_management.ipynb`: Comprehensive state management guide
  - `03_random_numbers.ipynb`: In-depth random number generation tutorial
- Neural Networks Tutorials: Restructured and expanded
  - `01_module_basics.ipynb`: New comprehensive module system guide
  - `02_basic_layers.ipynb`: Enhanced layer documentation with examples
  - `03_activations_normalization.ipynb`: Detailed activation and normalization guide
  - `04_recurrent_networks.ipynb`: New RNN tutorial with practical examples
  - `05_dynamics_systems.ipynb`: New dynamical systems tutorial
- Examples: Reorganized and enhanced
  - Renamed `10_image_classification.ipynb` to `01_image_classification.ipynb`
  - Renamed `11_sequence_modeling.ipynb` to `02_sequence_modeling.ipynb`
  - Added `03_brain_inspired_computing.ipynb`: New brain-inspired computing examples
  - Renamed `18_optimization_tricks.ipynb` to `04_optimization_tricks.ipynb`
  - Renamed `19_model_deployment.ipynb` to `05_model_deployment.ipynb`
- Transforms Tutorials: Reorganized for better flow
  - `01_jit_compilation.ipynb`: New comprehensive JIT guide
  - `02_automatic_differentiation.ipynb`: Enhanced autodiff tutorial
  - `03_vectorization.ipynb`: Improved vmap/pmap guide
  - `04_loops_conditions.ipynb`: Enhanced control flow guide
  - `05_other_transforms.ipynb`: Other transformation utilities
- Advanced Tutorials: Renumbered for clarity
  - `01_graph_operations.ipynb` (formerly `14_graph_operations.ipynb`)
  - `02_mixin_system.ipynb` (formerly `15_mixin_system.ipynb`)
  - `03_typing_system.ipynb` (formerly `16_typing_system.ipynb`)
  - `04_utilities.ipynb` (formerly `17_utilities.ipynb`)
- Migration Guides: Updated and simplified
  - `01_migration_from_pytorch.ipynb`: Enhanced PyTorch migration guide
  - Removed outdated BrainPy integration notebook
- Supplementary: Reorganized
  - `01_performance_optimization.ipynb`
  - `02_debugging_tips.ipynb`
  - `03_faq.ipynb`: Updated FAQ with new content
API Documentation
- Enhanced module documentation in `nn.rst` with 306 line improvements
- Updated `transform.rst` with new transform APIs
- Improved `environ.rst` and `graph.rst` documentation
Refactoring
- Removed deprecated `eval_shape` module and tests
- Removed deprecated `_random.py` transform module
- Cleaned up unused imports across all modules
- Improved code organization in neural network layers
- Enhanced type hints and docstrings throughout
Infrastructure
- Added development dependency for tutorial generation
- Updated benchmark scripts for performance testing
- Improved test coverage across transformation modules
What's Changed
- update logo by @xinzhu-L in #102
- Refactor random API: Extract distributions and rename modules by @chaoming0625 in #103
- Enhance stateful JAX transforms and update tutorials by @chaoming0625 in #104
- Updates by @oujago in #105
- Docs by @oujago in #106
- Enhance HiddenState and add HiddenGroupState, HiddenTreeState by @chaoming0625 in #107
New Contributors
Full Changelog: v0.2.0...v0.2.2
Version 0.2.0
This is a major release with significant refactoring, new features, and comprehensive documentation improvements.
Breaking Changes
- Module Deprecations: Deprecated `brainstate.augment`, `brainstate.compile`, and `brainstate.functional` modules in favor of `brainstate.transform` and `brainstate.nn`
  - Added deprecation proxies to guide users towards replacement modules
  - Updated all documentation and examples to use new module paths
- State Management: Replaced `write_back_state_values` with `assign_state_vals_v2` for improved state management
- Import Path Changes: Major refactoring of import paths across the codebase
  - Moved initialization references to use `brainstate.nn`
  - Updated random functions to use `brainstate.random`
  - Standardized imports across all modules
- Type System: Implemented `JointTypes` and `OneOfTypes` generic aliases to enhance type checking and avoid metaclass conflicts
  - Support for subscript syntax
  - Improved type hints across modules
- Copyright: Updated copyright notices to reflect new ownership by BrainX Ecosystem Limited
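Subscriptable type aliases of this kind can be built without metaclass tricks by returning plain checker callables from `__getitem__`. The sketch below is illustrative only — it assumes `JointTypes` means "instance of all listed types" and `OneOfTypes` means "instance of at least one", which is inferred from the names, not from the brainstate source:

```python
class _JointTypes:
    """Intersection check: a value must be an instance of every listed type."""
    def __getitem__(self, types):
        types = types if isinstance(types, tuple) else (types,)
        def check(obj):
            return all(isinstance(obj, t) for t in types)
        return check

class _OneOfTypes:
    """Union check: a value must be an instance of at least one listed type."""
    def __getitem__(self, types):
        types = types if isinstance(types, tuple) else (types,)
        def check(obj):
            return any(isinstance(obj, t) for t in types)
        return check

JointTypes = _JointTypes()
OneOfTypes = _OneOfTypes()

# Subscript syntax, as mentioned in the release notes:
from collections.abc import Sized, Iterable
is_sized_iterable = JointTypes[Sized, Iterable]
is_numberish = OneOfTypes[int, float]

print(is_sized_iterable([1, 2, 3]))  # True: lists are both Sized and Iterable
print(is_numberish("1.5"))           # False
```

Because `__getitem__` is defined on instances rather than on a metaclass, this composes with other base classes without metaclass conflicts.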
New Features
Neural Network Components
- Transposed Convolution Layers: Complete implementations for upsampling operations
  - `ConvTranspose1d`, `ConvTranspose2d`, `ConvTranspose3d`
  - Support for both channels-first and channels-last data formats via `channel_first` parameter
  - Configurable stride for controllable upsampling factors
  - Grouped transposed convolution support
  - Automatic padding computation for 'SAME' and 'VALID' modes
- Convolution Enhancements: Added support for both channels-first and channels-last data formats
  - New `channel_first` boolean parameter (default: `False`)
  - PyTorch-compatible format (e.g., `[B, C, H, W]`) when `channel_first=True`
  - Default JAX-style format (e.g., `[B, H, W, C]`) when `channel_first=False`
- Padding Layers: Added padding layers for 1D, 2D, and 3D tensors with various modes
- Unpooling Layers: Added `MaxUnpool1d`, `MaxUnpool2d`, and `MaxUnpool3d` with `return_indices` support
- Gradient Utilities: Implemented `clip_grad_norm` function for gradient clipping in PyTree structures
- Embedding Enhancements:
  - Added `padding_idx`, `max_norm`, and `norm_type` parameters
  - Improved gradient management with new `_contains_tracer` function
  - Optimized `max_norm` application with accessed mask for scaling
- BatchNorm Improvements: Added `feature_axis` and `track_running_stats` parameters
- LoRA Layer: Added `in_size` parameter for improved size handling
- Activation Functions: Added new activation functions and improved signatures
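Global-norm gradient clipping over a PyTree, the technique `clip_grad_norm` provides, can be sketched with NumPy over nested dicts; the actual brainstate signature and PyTree handling may differ from this stand-in:

```python
import numpy as np

def clip_grad_norm(grads, max_norm):
    """Sketch of global-norm clipping over a nested-dict PyTree:
    compute the joint L2 norm of all leaves, then rescale every
    leaf by min(1, max_norm / norm)."""
    leaves = []

    def collect(tree):
        if isinstance(tree, dict):
            for v in tree.values():
                collect(v)
        else:
            leaves.append(np.asarray(tree, dtype=float))

    collect(grads)
    global_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in leaves)))
    scale = min(1.0, max_norm / max(global_norm, 1e-12))

    def rescale(tree):
        if isinstance(tree, dict):
            return {k: rescale(v) for k, v in tree.items()}
        return np.asarray(tree, dtype=float) * scale

    return rescale(grads), global_norm

grads = {'w': np.array([3.0, 0.0]), 'b': {'bias': np.array([4.0])}}
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm)  # 5.0 (sqrt(3**2 + 4**2))
# every leaf is scaled by max_norm / norm = 0.2, e.g. w -> [0.6, 0.0]
```

Scaling all leaves by one joint factor preserves the direction of the update, unlike per-leaf clipping.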
Transform & Compilation
- StatefulMapping: Introduced for enhanced state management in vmap transformations
- Mixin Classes: Added `Mode`, `JointMode`, `Batching`, and `Training` classes for computation behavior control
- Bounded Cache: Implemented thread-safe bounded cache for JAX Jaxpr with:
  - Comprehensive validation
  - Statistics tracking
  - Enhanced error handling
- Input Validation: Enhanced input size handling to support numpy integer types
- Context Parameters: Update method now accepts additional context parameters for improved environment settings
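A thread-safe bounded cache with statistics tracking, of the kind described above for Jaxpr caching, can be sketched with an `OrderedDict` under a lock. This is an illustrative LRU sketch, not brainstate's `_BoundedCache` itself:

```python
import threading
from collections import OrderedDict

class BoundedCache:
    """Illustrative thread-safe LRU cache with hit/miss statistics."""

    def __init__(self, maxsize=128):
        if maxsize <= 0:
            raise ValueError("maxsize must be positive")
        self._maxsize = maxsize
        self._data = OrderedDict()
        self._lock = threading.Lock()
        self._hits = 0
        self._misses = 0

    def get(self, key, default=None):
        with self._lock:
            if key in self._data:
                self._hits += 1
                self._data.move_to_end(key)  # mark as recently used
                return self._data[key]
            self._misses += 1
            return default

    def put(self, key, value):
        with self._lock:
            self._data[key] = value
            self._data.move_to_end(key)
            if len(self._data) > self._maxsize:
                self._data.popitem(last=False)  # evict least recently used

    def stats(self):
        with self._lock:
            return {'size': len(self._data), 'hits': self._hits,
                    'misses': self._misses}

cache = BoundedCache(maxsize=2)
cache.put('a', 1)
cache.put('b', 2)
cache.get('a')       # hit; 'a' becomes most recently used
cache.put('c', 3)    # exceeds maxsize, so 'b' is evicted
print(cache.stats())  # {'size': 2, 'hits': 1, 'misses': 0}
```

Bounding the cache matters for traced Jaxprs because each distinct input shape/dtype combination produces a new entry, so an unbounded cache can grow without limit.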
Random & Initialization
- Dependencies: Integrated `braintools` for initialization and surrogate gradient functions
  - Updated all initialization references
  - Refactored to use `braintools.surrogate` for spike functions
- Random Functions: Replaced `uniform_for_unit` with `jr.uniform` for consistency and performance
Utilities & Infrastructure
- Filter Utilities: Added comprehensive filter utilities for nested structures
- Pretty Representation: Enhanced pretty_pytree module with:
  - Comprehensive documentation
  - Mapping functions
  - JAX integration
- Error Handling: Improved state length validation by replacing assertions with `ValueError` exceptions
- Collective Operations: Updated function signatures to return target in collective operations
Documentation
- Comprehensive Docstrings: Added detailed NumPy-style docstrings across all modules
  - Full parameter descriptions with types and default values
  - Multiple practical examples in code blocks
  - Comparison sections highlighting differences from PyTorch
  - Mathematical formulas where applicable
  - References to original papers
  - Best practices and use cases
- New Documentation Pages:
  - `brainstate.environ` module documentation
  - `brainstate.transform` (renamed from compile.rst)
  - Random number generation module
  - Pretty representation module
  - State management tutorial notebook
- Enhanced Examples: Updated documentation examples to use interactive prompts for clarity
- Module Descriptions: Enhanced documentation with detailed descriptions, key features, and usage examples
Testing
- Comprehensive Test Coverage: Added extensive test suites for:
  - `_BoundedCache` and `StatefulFunction`
  - `brainstate.mixin` module
  - `brainstate.environ` module (context management, precision settings, callbacks)
  - DeprecatedModule and proxy creation functionality
  - Compatible import module
  - Metrics module
  - Node class and helper functions
  - Activation functions with shape and gradient checks
  - Dropout layers
  - Surrogate gradient functions
  - Filter utilities
  - Struct module
  - Pretty representation
- Test Framework Updates: Refactored tests to use `absltest` for better JAX compatibility
Refactoring
- File Reorganization:
  - Renamed `metrics.py` to `_metrics.py`
  - Renamed `_rate_rnns.py` to `_rnns.py`
  - Renamed `_init.py` to `init.py`
  - Reorganized graph module files
  - Cleaned up unused imports and classes
- Code Quality:
  - Streamlined imports across all modules
  - Enhanced code formatting and whitespace consistency
  - Removed unnecessary inheritance and unused elements
  - Simplified type annotations
  - Improved method signatures for clarity
- Neuron & Synapse Classes: Refactored to use brainpy module and updated initialization methods
- Base Classes: Changed base class of `EINet` and `Net` from `DynamicsGroup` to `Module` for consistency
- Evaluation Functions: Refactored and updated method names for consistency
Infrastructure
- Version Bump: Updated version to 0.2.0
- Development Dependencies: Added `braintools` to development requirements
- Issue Templates: Added bug report and feature request templates for improved issue tracking
- CI/CD: Refactored CI configurations to update pip installation commands
- Git Ignore: Updated to exclude example figures directory and build artifacts
Bug Fixes
- Enhanced delay handling for multi-dimensional inputs
- Fixed gradient function references
- Improved deprecation handling in tests
- Fixed precision checks in complex number handling
New Contributors
Full Changelog: v0.1.10...v0.2.0
Version 0.1.10
What's Changed
- Fix precision checks in _get_complex to use 'in' for list membership by @chaoming0625 in #96
- ⬆️ Bump actions/setup-python from 5 to 6 by @dependabot[bot] in #98
- ⬆️ Bump actions/download-artifact from 4 to 5 by @dependabot[bot] in #97
- Enhance StateWithDelay class with detailed documentation and add dela… by @chaoming0625 in #99
- Bump version to 0.1.10 by @chaoming0625 in #100
Full Changelog: v0.1.9...v0.1.10