
Spatially-Enhanced Recurrent Units (SRU) - PyTorch Implementation

Paper | Website

📌 Important Note: This repository contains the code for Appendix A: Training details for spatial-temporal memorization task from the paper "Spatially-enhanced recurrent memory for long-range mapless navigation via end-to-end reinforcement learning" (IJRR 2025). This is the sru-pytorch-spatial-learning repository referenced on the project website.

About This Repository

This repository provides the core PyTorch implementation of Spatially-Enhanced Recurrent Units (SRU) and the experimental validation code for evaluating spatial-temporal memorization capabilities. The controlled experiments in Appendix A demonstrate SRU's superior spatial memory performance compared to standard recurrent architectures (LSTM, GRU) and state-space models (Mamba-SSM, S4).

Scope of this repository:

  • ✅ Core SRU architecture implementations (SRU-LSTM, SRU-GRU)
  • ✅ Spatial-temporal memorization experiments (Appendix A)
  • ✅ Baseline comparisons with LSTM, GRU, Mamba-SSM, and S4
  • ✅ Training and evaluation scripts for controlled memorization tasks

Not included in this repository:

  • ❌ Complete end-to-end navigation system (see main paper)
  • ❌ Full perception pipeline and real-world deployment code
  • ❌ Large-scale reinforcement learning navigation experiments

For the complete navigation system and real-world deployment, please refer to the related repositories on the project website.

The Spatial Memory Challenge

The Problem: While standard RNNs (LSTM, GRU) excel at capturing temporal dependencies, our research reveals a critical limitation: they struggle to effectively perform spatial memorization—the ability to transform and integrate sequential observations from varying perspectives into coherent spatial representations.

Think of it like this:

  • Temporal Memorization (RNNs excel): Remembering the sequence "A → B → C" happened in that order
  • Spatial Memorization (RNNs struggle): Transforming "I saw landmark A from position 1, then from position 2" into understanding where landmark A actually is in space

Classical approaches achieve spatial registration through homogeneous transformations, but standard RNNs struggle to learn these spatial transformations implicitly, even when provided with ego-motion information.
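For intuition, the classical registration step can be written out directly: a landmark observed in the previous robot frame is mapped into the current frame by a 4×4 homogeneous transform built from the relative rotation and translation. The sketch below uses illustrative numbers (a 90° turn and a 1 m forward step); it is not code from this repository.

```python
import numpy as np

# Relative ego-motion between frames (illustrative values):
# the robot turns 90 degrees about z and moves 1 m along its x-axis.
theta = np.pi / 2
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0,            0.0,           1.0]])   # new heading in the old frame
t = np.array([1.0, 0.0, 0.0])                          # translation in the old frame

# Homogeneous transform mapping coordinates from the previous frame
# into the current frame: p_curr = R^T (p_prev - t).
M = np.eye(4)
M[:3, :3] = R.T
M[:3, 3] = -R.T @ t

p_prev = np.array([2.0, 0.0, 0.0, 1.0])  # landmark 2 m ahead, previous frame
p_curr = M @ p_prev                      # same landmark, current frame
# After a left turn, the landmark that was straight ahead now lies to the
# robot's right: p_curr[:3] is approximately [0, -1, 0].
```

This single matrix multiply is exactly the operation that standard RNNs fail to learn implicitly, and that the SRU's spatial gate is designed to make learnable.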

SRU Architecture

The Solution: SRUs enhance standard LSTM/GRU cells with a simple spatial transformation gate applied through element-wise multiplication. The key innovation is the s_t ⊙ (...) operation, which enables the network to implicitly learn spatial transformations from ego-centric observations.

Network Architecture Overview

Network Architecture

Figure: Overall Training Pipeline and Network Structure

The network processes ego-motion transformations ($M_t^{t-1}$), landmark coordinates ($l_t^i$), and categorical labels ($c^i$) through MLP and RNN layers. The RNN (shown in red) is the component enhanced with spatial transformation gates. The network outputs the predicted landmark coordinates ($l_T^i$) and categorical labels ($c^i$) through output MLPs.

Implemented Models

  • SRU-LSTM: LSTM with Additive Spatial Transformation Gates
  • SRU-GRU: GRU with Additive Spatial Transformation Gates
  • SRU-LSTM-Gated: LSTM with Additive Transformation and Gated Refinement
  • Baselines: Standard LSTM, GRU, Mamba-SSM, and S4 (simplified implementation)

SRU-LSTM

s_t = W_xs·x_t + b_s                                   # Spatial transformation term
g_t = tanh(s_t ⊙ (W_xg·x_t + W_hg·h_{t-1} + b_g))      # Modified candidate gate
c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t                        # Cell state update
h_t = o_t ⊙ tanh(c_t)                                  # Hidden state output
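A single SRU-LSTM step following the equations above can be sketched in a few lines of PyTorch. This is an illustrative minimal cell, not the repository's LSTM_SRU class; the gate packing and layer layout here are assumptions.

```python
import torch
import torch.nn as nn

class SRULSTMCell(nn.Module):
    """One SRU-LSTM step: a learned spatial term s_t modulates the
    candidate gate g_t element-wise (the s_t ⊙ (...) operation).
    Illustrative sketch only; the repository's LSTM_SRU may differ."""
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.spatial = nn.Linear(input_size, hidden_size)                  # W_xs x_t + b_s
        self.gates = nn.Linear(input_size + hidden_size, 3 * hidden_size)  # i_t, f_t, o_t packed
        self.cand = nn.Linear(input_size + hidden_size, hidden_size)       # g_t pre-activation

    def forward(self, x_t, state):
        h_prev, c_prev = state
        z = torch.cat([x_t, h_prev], dim=-1)
        i_t, f_t, o_t = torch.sigmoid(self.gates(z)).chunk(3, dim=-1)
        s_t = self.spatial(x_t)                       # spatial transformation term
        g_t = torch.tanh(s_t * self.cand(z))          # s_t ⊙ (W_xg x_t + W_hg h_{t-1} + b_g)
        c_t = f_t * c_prev + i_t * g_t                # cell state update
        h_t = o_t * torch.tanh(c_t)                   # hidden state output
        return h_t, (h_t, c_t)

# One forward step on a batch of 4 inputs of size 15, hidden size 128.
cell = SRULSTMCell(15, 128)
x = torch.randn(4, 15)
h, (h_new, c_new) = cell(x, (torch.zeros(4, 128), torch.zeros(4, 128)))
```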

SRU-GRU

s_t = W_xs·x_t + b_s                                        # Spatial transformation term
h̃_t = tanh(s_t ⊙ (W_xh·x_t + W_hh·(r_t ⊙ h_{t-1}) + b_h))   # Modified candidate hidden state
h_t = (1 - z_t) ⊙ h̃_t + z_t ⊙ h_{t-1}                       # Hidden state update
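The SRU-GRU variant admits the same kind of minimal sketch: the spatial term s_t gates the candidate hidden state. Again, this is an assumption-laden illustration, not the repository's GRU_SRU implementation.

```python
import torch
import torch.nn as nn

class SRUGRUCell(nn.Module):
    """One SRU-GRU step with the candidate hidden state modulated by the
    spatial term s_t. Illustrative sketch; the repository's GRU_SRU may differ."""
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.spatial = nn.Linear(input_size, hidden_size)               # W_xs x_t + b_s
        self.rz = nn.Linear(input_size + hidden_size, 2 * hidden_size)  # reset r_t, update z_t
        self.cand = nn.Linear(input_size + hidden_size, hidden_size)    # candidate pre-activation

    def forward(self, x_t, h_prev):
        r_t, z_t = torch.sigmoid(self.rz(torch.cat([x_t, h_prev], -1))).chunk(2, -1)
        s_t = self.spatial(x_t)                                         # spatial transformation term
        h_tilde = torch.tanh(s_t * self.cand(torch.cat([x_t, r_t * h_prev], -1)))
        return (1 - z_t) * h_tilde + z_t * h_prev                       # hidden state update

# One forward step on a batch of 4 inputs of size 15, hidden size 128.
h = SRUGRUCell(15, 128)(torch.randn(4, 15), torch.zeros(4, 128))
```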

Repository Structure

sru-pytorch-spatial-learning/
├── network/                    # SRU and baseline implementations
│   ├── __init__.py
│   ├── lstm_sru.py            # LSTM_SRU implementation
│   ├── gru_sru.py             # GRU_SRU implementation
│   ├── vanilla_mamab.py       # MambaNet implementation (optional)
│   └── s4_utils/              # S4 utilities (simplified implementation)
├── setup.py                    # Package setup configuration
├── pyproject.toml             # Modern Python packaging
├── SETUP_GUIDE.md             # Installation and setup guide
├── IMPORT_EXAMPLES.md         # Detailed usage examples
├── QUICK_REFERENCE.md         # Quick lookup table
├── example_usage.py           # Runnable example demonstrating all networks
├── run_pointcloud.py          # Training/evaluation script
├── visualize_pointobs.py      # 3D visualization tool
├── dataloader/                # Dataset loaders
├── utils/                     # Helper functions
├── params/pointcloud.yaml     # Configuration file
├── data/cloud/                # Dataset storage
├── models/cloud/              # Model checkpoints
└── figures/cloud/             # Visualizations

Installation

Requirements: Python 3.8+, PyTorch 2.0+, CUDA 11.8+ (for GPU)

# Create environment
conda create -n sru python=3.10 -y
conda activate sru

# Install PyTorch with CUDA support
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124

# Install dependencies
pip install pypose mamba-ssm causal-conv1d matplotlib scipy scikit-learn wandb pyyaml

Note: Adjust cu124 to match your installed CUDA version. mamba-ssm and causal-conv1d are only required for the MambaNet baseline. For troubleshooting, see SETUP_GUIDE.md.

Development Installation (for using/modifying the SRU networks)

# From the repository root
pip install -e .

Quick Start - Using the Networks

The package provides three SRU network implementations:

import torch
from network import LSTM_SRU, LSTM_SRU_Gate, GRU_SRU

# Create models
lstm_sru = LSTM_SRU(input_size=15, hidden_size=128, num_layers=2, batch_first=True)
lstm_sru_gate = LSTM_SRU_Gate(input_size=15, hidden_size=128, num_layers=2, batch_first=True)
gru_sru = GRU_SRU(input_size=15, hidden_size=128, num_layers=2, batch_first=True)

# Forward pass
x = torch.randn(4, 10, 15)  # (batch_size, seq_len, input_size)
output, state = lstm_sru(x)  # output: (4, 10, 128)

Available SRU Networks:

  • LSTM_SRU: LSTM with Additive Spatial Transformation Gates
  • LSTM_SRU_Gate: LSTM with Additive Transformation and Gated Refinement
  • GRU_SRU: GRU with Additive Spatial Transformation Gates

Note: MambaNet and S4 are included in the repository as baseline implementations for experimental comparisons, but are not part of the pip-installable package.

For detailed usage examples and patterns, see IMPORT_EXAMPLES.md.

Usage

Training

Train models on the point cloud prediction task:

# Train with default settings
python run_pointcloud.py --train

# Train with Weights & Biases logging
python run_pointcloud.py --train --wandb

Evaluation

Evaluate trained models:

python run_pointcloud.py

This will load the most recent checkpoint and visualize predictions.

Configuration

Edit params/pointcloud.yaml to modify:

  • Model architecture (hidden size, num layers)
  • Training hyperparameters (learning rate, batch size, epochs)
  • Dataset parameters (sequence length, rotation scale)
  • Optimizer settings

Experimental Tasks

This repository implements the spatial-temporal memorization experiments described in our paper:

Spatial-Temporal Memory Task

At each timestep t, the agent receives:

  • Landmark coordinates l_t^i ∈ ℝ³ in the robot's current frame
  • Binary categorical labels c^i associated with each landmark
  • Ego-motion transformation matrix M_{t-1→t} from previous to current frame

After T timesteps, the network must:

  1. Spatial Task: Transform and memorize all observed landmark coordinates into the final frame at t=T
  2. Temporal Task: Recall all binary labels in sequential order
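The task setup above can be sketched as a small synthetic episode generator: landmarks are fixed in the world, observed each step in the robot's current frame, and the spatial target is their coordinates expressed in the final frame. All helper names, value ranges, and episode sizes below are illustrative assumptions, not the repository's dataloader.

```python
import numpy as np

def make_episode(T=10, n_landmarks=5, seed=None):
    """Generate one spatial-temporal memorization episode (illustrative).
    Returns per-step ego-centric observations, per-step relative ego-motion,
    binary labels, and the spatial target (landmarks in the final frame)."""
    rng = np.random.default_rng(seed)
    world = rng.uniform(-5.0, 5.0, size=(n_landmarks, 3))  # fixed world-frame landmarks
    labels = rng.integers(0, 2, size=n_landmarks)          # binary categorical labels c^i

    obs, rel_motions = [], []
    pos, yaw = np.zeros(3), 0.0
    for _ in range(T):
        d_yaw = rng.uniform(-0.3, 0.3)                     # relative ego-motion this step
        d_pos = rng.uniform(-0.5, 0.5, size=3)
        yaw += d_yaw
        pos = pos + d_pos
        c, s = np.cos(yaw), np.sin(yaw)
        R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])  # robot orientation in world
        obs.append((world - pos) @ R)                      # landmarks l_t^i in the current frame
        rel_motions.append((d_yaw, d_pos))                 # summarizes M_{t-1 -> t}

    target = obs[-1]                                       # spatial task: landmarks in frame t=T
    return np.stack(obs), rel_motions, labels, target

obs, rel_motions, labels, target = make_episode(T=10, n_landmarks=5, seed=0)
```

A model trained on such episodes only sees the per-step observations and relative motions; recovering `target` requires it to chain the frame-to-frame transformations internally, which is exactly what the spatial task probes.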

Baseline Comparisons

We compare SRU units against:

  • Standard RNNs: LSTM, GRU
  • State-Space Models: Mamba-SSM, S4

Experimental Results

Key Findings

| Metric | Standard RNNs | SRU Units |
| --- | --- | --- |
| Temporal Memorization | ✅ Converge quickly | ✅ Converge quickly |
| Spatial Memorization | ❌ Struggle to learn | ✅ Learn effectively |

Training Loss

The temporal task requires models to recall binary labels in sequential order, while the spatial task requires models to transform and memorize landmark coordinates. Below are the results showing SRU's superior spatial memorization capabilities:

Temporal Training Loss
Figure 1a: Temporal Task Loss
SRU and baseline RNNs show comparable convergence speed on sequential memorization.
Spatial Training Loss
Figure 1b: Spatial Task Loss
SRU achieves lower loss and faster convergence compared to LSTM, GRU, and state-space models.

Spatial Coordinate Mapping

Qualitative comparison of spatial coordinate predictions from LSTM and SRU models:

LSTM Spatial Mapping
Figure 2a: LSTM Spatial Mapping
Standard LSTM fails to accurately transform and register landmark coordinates across time steps.
SRU Spatial Mapping
Figure 2b: SRU Spatial Mapping
SRU effectively registers landmark positions, demonstrating superior spatial transformation capabilities.

Citation

If you use this code in your research, please cite:

@article{yang2025sru,
  author = {Yang, Fan and Frivik, Per and Hoeller, David and Wang, Chen and Cadena, Cesar and Hutter, Marco},
  title = {Spatially-enhanced recurrent memory for long-range mapless navigation via end-to-end reinforcement learning},
  journal = {The International Journal of Robotics Research},
  year = {2025},
  doi = {10.1177/02783649251401926},
  url = {https://doi.org/10.1177/02783649251401926}
}

Contact

For questions or issues, please contact:

License

MIT License

Copyright (c) 2024-2025 Fan Yang, Robotic Systems Lab, ETH Zurich

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Acknowledgments

This project builds upon:
