Spherinator: Model Training
Spherinator provides representation learning using autoencoders to compress generic data to a low-dimensional latent space. The primary model is a Variational Autoencoder (VAE) with a spherical latent space based on a Power Spherical distribution, which is particularly well-suited for the interactive visualization of high-dimensional data such as images, spectra, point clouds, and cubes.
The encoder compresses the input into a location vector on the unit hypersphere and a concentration
scale. The decoder reconstructs the input from a sample drawn from the Power Spherical distribution
parameterized by these two values. The VariationalEncoder wraps any backbone encoder with two linear heads:
- fc_location: maps the backbone output to a unit-normalized location vector of dimension z_dim.
- fc_scale: maps the backbone output to a positive concentration scalar (via softplus + 1 to avoid collapse).
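The following plain-Python sketch shows what the two heads compute. It assumes each head is a single affine map applied to the backbone output h. The helper names (linear, location_head, scale_head) and all weight values are hypothetical, and the real module works on PyTorch tensors rather than lists.

```python
import math


def linear(weights, bias, h):
    """Affine map: one output per row of `weights` (hypothetical helper)."""
    return [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weights, bias)]


def location_head(weights, bias, h):
    """fc_location sketch: affine map, then L2-normalise onto the unit sphere."""
    v = linear(weights, bias, h)
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def scale_head(weights, bias, h):
    """fc_scale sketch: affine map to one scalar, then softplus + 1.

    softplus is positive, so the concentration is always at least 1.
    """
    (s,) = linear([weights], [bias], h)
    return softplus(s) + 1.0


h = [0.5, -1.0, 2.0, 0.25]          # hypothetical backbone output (dim 4)
W_loc = [[1.0, 0.0, 0.0, 0.0],
         [0.0, 1.0, 0.0, 0.0],
         [0.0, 0.0, 1.0, 0.0]]      # z_dim = 3
mu = location_head(W_loc, [0.0, 0.0, 0.0], h)     # point on S^2
kappa = scale_head([0.1, 0.1, 0.1, 0.1], 0.0, h)  # concentration > 1
```

The normalization guarantees that every location vector lies on the sphere regardless of the backbone's output scale. The "+ 1" offset keeps the concentration away from zero, where the distribution would approach the uniform distribution on the sphere.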
Installation
Spherinator can be installed via pip:
```bash
pip install spherinator
```
Training the model
Training requires a YAML configuration file that specifies the data, model architecture, and training parameters. Multiple config files can be composed on the command line; later files override earlier ones.
```bash
spherinator fit -c config.yaml
```
Individual arguments can be overridden inline:
```bash
spherinator fit -c config.yaml \
  --model.init_args.z_dim 16 \
  --trainer.devices [0,1] \
  --trainer.max_epochs 100
```
Configs can be composed by chaining multiple -c flags:
```bash
spherinator fit \
  -c experiments/illustris.yaml \
  -c experiments/vae_vit.yaml \
  -c experiments/wandb.yaml
```
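One way to picture the rule "later files override earlier ones" is as a recursive dictionary merge in which later values win. The sketch below is a simplified mental model under that assumption. It is not the CLI's actual merging code (the command line appears to follow LightningCLI/jsonargparse conventions, which apply further rules, for example when a class_path changes). merge_configs is a hypothetical helper.

```python
import copy


def _merge(base, override):
    """Return a copy of `base` with `override` applied on top (later wins)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _merge(merged[key], value)  # recurse into nested sections
        else:
            merged[key] = copy.deepcopy(value)        # scalars and lists are replaced
    return merged


def merge_configs(*configs):
    """Merge config dicts left to right, like chaining -c flags (simplified)."""
    result = {}
    for cfg in configs:
        result = _merge(result, cfg)
    return result


base = {"model": {"init_args": {"z_dim": 3, "beta": 1.0e-3}},
        "trainer": {"max_epochs": 500}}
override = {"model": {"init_args": {"z_dim": 16}},
            "trainer": {"max_epochs": 100}}
cfg = merge_configs(base, override)
# z_dim and max_epochs come from `override`; beta survives from `base`.
```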
Training can be resumed from a checkpoint:
```bash
spherinator fit -c config.yaml --ckpt_path path/to/checkpoint.ckpt
```
DataModule
The DataModule loads data from Apache Parquet files, applies
optional per-column transforms, and feeds batches to the model. It is defined in the data section
of the YAML file.
```yaml
data:
  class_path: spherinator.data.DataModule
  init_args:
    path: ./data/illustris_SKIRT_synthetic_images/parquet-128
    columns:
      - name: data
        transform:
          class_path: torchvision.transforms.v2.Resize
          init_args:
            size: 224
    return_dict: False
    batch_size: 16
    shuffle: True
    num_workers: 4
```
| Parameter | Description |
|---|---|
| path | Path to the directory containing Parquet files |
| columns | List of columns to load; each entry may carry a transform |
| return_dict | If … |
| batch_size | Number of samples per mini-batch |
| shuffle | Shuffle the dataset each epoch |
| num_workers | Number of parallel data-loading workers |
Model architecture
The top-level model is defined in the model section of the YAML file.
Two model classes are available:
| Class | Description |
|---|---|
| spherinator.models.Autoencoder | Deterministic autoencoder |
| spherinator.models.VariationalAutoencoder | VAE with Power Spherical latent distribution |
Autoencoder
The deterministic autoencoder encodes directly to a fixed-size vector without sampling:
```yaml
model:
  class_path: spherinator.models.Autoencoder
  init_args:
    encoder:
      class_path: ...
    decoder:
      class_path: ...
    reconstruction_loss: torch.nn.MSELoss
```
VariationalAutoencoder
```yaml
model:
  class_path: spherinator.models.VariationalAutoencoder
  init_args:
    encoder:
      class_path: ...      # any encoder below
    decoder:
      class_path: ...      # any decoder below
    encoder_out_dim: 64    # must match encoder output_dim
    z_dim: 3               # latent space dimensionality
    beta: 1.0e-3           # KL weight (beta-VAE)
    reconstruction_loss: torch.nn.MSELoss
```
| Parameter | Description |
|---|---|
| encoder_out_dim | Must match the encoder's output_dim |
| z_dim | Dimension of the spherical latent space; 3 maps to a sphere (S²) |
| beta | Scales the KL divergence term relative to the reconstruction loss |
| reconstruction_loss | Reconstruction loss |
| … | Maximum concentration scale to prevent collapse |
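For orientation, these parameters enter the training objective as in a standard β-VAE. The following is a sketch that assumes the prior $p(z)$ is the uniform distribution on the unit hypersphere $S^{d-1}$ with $d$ = z_dim (the prior is not stated explicitly above). Here $\mu$ and $\kappa$ denote the outputs of fc_location and fc_scale:

$$\mathcal{L} = \mathcal{L}_{\text{rec}}(\hat{x},\, x) + \beta \cdot D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big), \qquad q(z \mid x) = \mathrm{PowerSpherical}(\mu,\, \kappa)$$

A larger beta pulls each per-sample distribution toward the prior. A smaller beta favours reconstruction quality.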
Encoder architectures
The encoder architecture should be chosen to match the input data type.
ConvolutionalEncoder2D
Standard 2D CNN encoder built from ConsecutiveConv2DLayer blocks. Each block is a sequence of
LazyConv2d layers with optional batch normalization, activation, and pooling. The output is
flattened and projected to output_dim via a lazy linear layer.
```yaml
encoder:
  class_path: spherinator.models.ConvolutionalEncoder2D
  init_args:
    input_dim: [3, 128, 128]
    output_dim: 64
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [16, 20, 24]
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 4
          stride: 2
          padding: 0
          out_channels: [32, 64]
```
Each entry in out_channels adds one convolutional layer. The ConsecutiveConv2DLayer arguments
are:
| Argument | Description |
|---|---|
| kernel_size | Kernel size for all layers in this block |
| stride | Stride for all layers in this block |
| padding | Padding for all layers in this block |
| out_channels | List of output channel counts; one layer per entry |
| activation | Activation function class (default: …) |
| … | Normalization class (default: …) |
| … | Optional pooling module appended after each layer |
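Because the layers are lazy, the flattened size that reaches the final linear layer is only determined at the first forward pass. The sketch below traces the spatial size through the example above using PyTorch's standard convolution output formula. It assumes dilation 1 and no pooling module. The helper names are hypothetical.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output length of a Conv1d/Conv2d layer with dilation 1 (PyTorch formula)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def encoder_spatial_size(size, blocks):
    """blocks: (kernel_size, stride, padding, out_channels) per ConsecutiveConv2DLayer."""
    for kernel_size, stride, padding, out_channels in blocks:
        for _ in out_channels:  # one conv layer per out_channels entry
            size = conv_out(size, kernel_size, stride, padding)
    return size


# The two blocks from the ConvolutionalEncoder2D example above.
blocks = [(3, 1, 0, [16, 20, 24]), (4, 2, 0, [32, 64])]
side = encoder_spatial_size(128, blocks)
# channels x H x W that the lazy linear layer receives after flattening
flat_features = blocks[-1][3][-1] * side * side
```

For this configuration, 128 shrinks to 122 after the first block and to 29 after the second. The lazy linear layer therefore sees 64 × 29 × 29 = 53,824 features and projects them to output_dim.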
ConvolutionalEncoder1D
Analogous to ConvolutionalEncoder2D but for 1D inputs such as spectra, using
ConsecutiveConv1DLayer with LazyConv1d layers.
HuggingFaceViTEncoder
Wraps any HuggingFace Vision Transformer. The CLS token from the last hidden state is optionally
projected to output_dim via a linear layer.
```yaml
encoder:
  class_path: spherinator.models.HuggingFaceViTEncoder
  init_args:
    model_name: google/vit-base-patch16-224
    output_dim: 64
    freeze: False
```
| Argument | Description |
|---|---|
| model_name | HuggingFace model identifier; must be a ViT variant |
| output_dim | Output projection size; if … |
| freeze | If … |
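To make the shapes concrete, the arithmetic below traces google/vit-base-patch16-224 from image to projected embedding. It assumes the standard ViT layout: 16×16-pixel patches are flattened into a token sequence, and the CLS token is prepended at position 0 of last_hidden_state. It also uses the ViT-Base hidden size of 768. The helper name vit_token_count is hypothetical.

```python
def vit_token_count(image_size, patch_size):
    """Return (patch grid side, sequence length incl. CLS) for a square ViT input."""
    grid = image_size // patch_size
    return grid, grid * grid + 1


# google/vit-base-patch16-224: 224x224 input, 16x16-pixel patches
grid, seq_len = vit_token_count(224, 16)

hidden_size = 768        # ViT-Base hidden size
cls_index = 0            # CLS token is the first token of last_hidden_state
output_dim = 64          # from the YAML above
projection_shape = (hidden_size, output_dim)  # in/out features of the optional projection
```

The encoder therefore keeps one 768-dimensional vector per image (the CLS token out of 197 tokens). With output_dim: 64, that vector is projected to 64 dimensions.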
Decoder architectures
ConvolutionalDecoder2D
Transposed-convolution decoder. A linear layer maps the latent vector to a seed spatial tensor;
ConsecutiveConvTranspose2DLayer blocks then upsample it to the target resolution.
```yaml
decoder:
  class_path: spherinator.models.ConvolutionalDecoder2D
  init_args:
    input_dim: 3
    output_dim: [3, 128, 128]
    cnn_input_dim: [64, 28, 28]
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 5
          stride: 2
          padding: 0
          out_channels: [32]
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [20, 16, 3]
          activation: null
```
cnn_input_dim sets the shape [C, H, W] into which the output of the seed linear projection is
reshaped. The stacked transposed-convolution blocks must then grow this spatial size to output_dim[1:].
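Checking that a transposed-convolution stack reaches output_dim[1:] is simple arithmetic. This sketch uses PyTorch's ConvTranspose2d output-size formula with dilation 1 and output_padding 0. It assumes each ConsecutiveConvTranspose2DLayer applies one plain transposed convolution per out_channels entry, with no extra resizing. The helper names and the mirrored stack are hypothetical.

```python
def conv_transpose_out(size, kernel_size, stride=1, padding=0, output_padding=0):
    """Output length of ConvTranspose1d/2d with dilation 1 (PyTorch formula)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


def decoder_spatial_size(size, blocks):
    """blocks: (kernel_size, stride, padding, out_channels); one layer per channel entry."""
    for kernel_size, stride, padding, out_channels in blocks:
        for _ in out_channels:
            size = conv_transpose_out(size, kernel_size, stride, padding)
    return size


# Hypothetical mirror of the ConvolutionalEncoder2D example: 29 -> 60 -> 122 -> 124 -> 126 -> 128
mirror = [(4, 2, 0, [64, 32]), (3, 1, 0, [24, 20, 3])]
mirror_side = decoder_spatial_size(29, mirror)

# The ConvolutionalDecoder2D example above, seeded at 28x28
example = [(5, 2, 0, [32]), (3, 1, 0, [20, 16, 3])]
example_side = decoder_spatial_size(28, example)
```

Under these assumptions the mirrored stack lands exactly on 128. The values in the YAML above, by contrast, evaluate to 65 × 65 rather than 128 × 128. Running this kind of check, or a single forward pass, before a long training run is therefore worthwhile.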
UpsamplingDecoder2D
Bilinear-upsampling decoder that avoids the checkerboard artifacts of transposed convolutions. Each
_UpsampleBlock doubles the spatial resolution with a bilinear upsample followed by a 3×3
convolution, batch normalization, and ReLU. A final 1×1 convolution maps to the output channels,
followed by a sigmoid.
```text
z → Linear → reshape (base_channels, seed_size, seed_size)
  → n × UpsampleBlock (2×)
  → 1×1 Conv → Sigmoid → output
```
The number of upsampling steps n is inferred automatically so that seed_size × 2ⁿ reaches output_dim[1].
```yaml
decoder:
  class_path: spherinator.models.UpsamplingDecoder2D
  init_args:
    input_dim: 64
    output_dim: [3, 224, 224]
    base_channels: 512
    seed_size: 7
```
| Argument | Description |
|---|---|
| input_dim | Latent vector size |
| output_dim | Target image shape |
| base_channels | Channel count at the spatial seed; halved at each upsampling step |
| seed_size | Spatial width/height of the seed feature map |
The default seed_size of 7 with five upsampling steps reaches 224×224
($7 \times 2^5 = 224$), which matches the 224×224 input resolution of ViT-Base (google/vit-base-patch16-224).
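The sketch below shows how the number of steps and the channel halving can be inferred. The function name upsampling_plan is hypothetical. Raising an error when output_dim[1] is not seed_size times a power of two is an assumption of this sketch, not documented behaviour.

```python
import math


def upsampling_plan(seed_size, target_size, base_channels):
    """Infer the number of 2x upsampling steps and the channel count after each step."""
    n_steps = round(math.log2(target_size / seed_size))
    if seed_size * 2 ** n_steps != target_size:
        raise ValueError("target size must be seed_size times a power of two")
    # channels[0] is at the seed; each step halves the channel count
    channels = [base_channels // 2 ** i for i in range(n_steps + 1)]
    return n_steps, channels


n_steps, channels = upsampling_plan(seed_size=7, target_size=224, base_channels=512)
```

For the example this gives five steps with 512 → 256 → 128 → 64 → 32 → 16 channels. The final 1×1 convolution then maps those 16 channels to the 3 output channels.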
ConvolutionalDecoder1D
Mirror of ConvolutionalDecoder2D for 1D outputs using ConsecutiveConvTranspose1DLayer.
Loss functions
The reconstruction_loss field of any model accepts any nn.Module. Spherinator ships
two purpose-built losses in addition to the standard PyTorch losses.
PerceptualLoss
spherinator.losses.PerceptualLoss computes a VGG-16 feature-matching loss for sharper,
perceptually more faithful image reconstructions. A frozen VGG-16 backbone is used as a fixed
feature extractor; the loss is the mean squared error between the activations of the chosen
intermediate layers for the reconstruction and the target.
Single-channel (grayscale) inputs are automatically broadcast to three channels before being passed through the network.
```yaml
reconstruction_loss:
  class_path: spherinator.losses.PerceptualLoss
  init_args:
    layers: [3, 8, 15]          # VGG-16 layer indices to tap
    weights: [1.0, 1.0, 1.0]    # per-layer loss weights
```
| Argument | Description |
|---|---|
| layers | Indices into …; selects the VGG-16 layers to tap |
| weights | Scalar weight for each tapped layer’s MSE contribution (default: all …) |
The default layer indices 3, 8, and 15 correspond, in the standard torchvision VGG-16 layer ordering, to relu1_2, relu2_2, and relu3_3: the last ReLU of each of the first three convolutional blocks. Together they capture progressively higher-level features.
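Written out, the perceptual loss described above is a weighted sum of per-layer feature MSEs. In this sketch, $\phi_l$ denotes the frozen VGG-16 activation at a tapped index $l$ (an entry of layers), and $w_l$ is the matching entry of weights. Both symbols are introduced here for illustration:

$$\mathcal{L}_{\text{perc}}(\hat{x},\, x) = \sum_{l \in \text{layers}} w_l \cdot \mathrm{MSE}\big(\phi_l(\hat{x}),\, \phi_l(x)\big)$$

With the YAML above, this sums the three tapped terms with equal weight 1.0.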
CombinedLoss
spherinator.losses.CombinedLoss forms a weighted sum of multiple loss functions, making it
easy to blend pixel-level and perceptual objectives:
$$\mathcal{L} = \sum_i w_i \cdot \mathcal{L}_i(\hat{x},\, x)$$
where $\mathcal{L}_i$ is the $i$-th entry of losses and $w_i$ the matching entry of factors.
```yaml
reconstruction_loss:
  class_path: spherinator.losses.CombinedLoss
  init_args:
    losses:
      - class_path: torch.nn.MSELoss
      - class_path: spherinator.losses.PerceptualLoss
        init_args:
          layers: [3, 8, 15]
          weights: [1.0, 1.0, 1.0]
    factors: [1.0, 0.1]
```
| Argument | Description |
|---|---|
| losses | List of … |
| factors | Scalar weight for each corresponding loss (must be the same length as losses) |
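The weighted sum can be sketched in plain Python. WeightedLossSum is a hypothetical stand-in for CombinedLoss that operates on lists rather than tensors. Mean absolute error substitutes for PerceptualLoss, because a VGG-16 backbone cannot be reproduced in a few lines. The length check mirrors the documented requirement that losses and factors have the same length.

```python
def mse(pred, target):
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error (stand-in for a second loss term)."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


class WeightedLossSum:
    """Hypothetical CombinedLoss analogue: L = sum_i factors[i] * losses[i](pred, target)."""

    def __init__(self, losses, factors):
        if len(losses) != len(factors):
            raise ValueError("losses and factors must have the same length")
        self.losses = losses
        self.factors = factors

    def __call__(self, pred, target):
        return sum(f * loss(pred, target) for loss, f in zip(self.losses, self.factors))


combined = WeightedLossSum([mse, mae], [1.0, 0.1])
value = combined([0.0, 1.0, 2.0], [0.0, 0.0, 4.0])  # 1.0 * 5/3 + 0.1 * 1.0
```

The small factor on the second term (0.1, as in the YAML) keeps the pixel-level loss dominant while still letting the auxiliary objective shape the reconstructions.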
Optimizer
Any PyTorch optimizer can be specified in the optimizer section:
```yaml
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
```
Trainer
The trainer section maps directly to
Lightning’s Trainer:
```yaml
trainer:
  max_epochs: 500
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  enable_progress_bar: True
  enable_model_summary: True
```
Common precision options: 32, 16-mixed, bf16-mixed.
Weights & Biases integration
Add -c experiments/wandb.yaml to the command to enable W&B logging, and edit entity and
tags as needed:
```yaml
# experiments/wandb.yaml
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: spherinator
      log_model: True
      entity: <your-wandb-entity>
      tags:
        - my_experiment
```
Callbacks
Callbacks are appended as additional YAML files or inline under trainer.callbacks.
Log reconstructions during training
experiments/callback_log_reconstructions.yaml logs a fixed set of sample reconstructions to W&B
after every validation epoch. Requires W&B to be configured.
```yaml
# experiments/callback_log_reconstructions.yaml
trainer:
  callbacks:
    - class_path: spherinator.callbacks.LogReconstructionCallback
      init_args:
        samples: 6
```
KL annealing
spherinator.callbacks.KLAnnealing gradually ramps the KL-divergence weight beta of a
VariationalAutoencoder during training. Starting with a small (or zero) KL weight helps prevent
posterior collapse in the early epochs, because the encoder can first learn a good reconstruction
before the regularisation pressure is increased.
The schedule supports cyclic annealing: the ramp from start to end is repeated n_cycles
times over the n_epochs training epochs, which is reported to improve latent-space utilisation.
The ratio parameter controls what fraction of each cycle is spent ramping; for the remainder of the
cycle, beta stays at end.
```yaml
trainer:
  callbacks:
    - class_path: spherinator.callbacks.KLAnnealing
      init_args:
        start: 0.0
        end: 1.0e-2
        n_epochs: 500
        n_cycles: 4
        ratio: 0.5
```
| Argument | Description |
|---|---|
| start | Initial value of beta |
| end | Target value of beta |
| n_epochs | Total training epochs; controls the overall period (default: …) |
| n_cycles | Number of times the ramp is repeated (default: …) |
| ratio | Fraction of each cycle spent linearly ramping from start to end |
The current beta value is logged to the trainer as "beta" each epoch so it can be tracked
in W&B or TensorBoard.
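The following is a minimal sketch of a cyclical linear schedule that matches the description above. It assumes a cycle length of n_epochs / n_cycles and a linear ramp from start to end over the first ratio of each cycle. The name cyclic_beta is hypothetical, and the exact formula inside KLAnnealing may differ, for example at cycle boundaries or beyond n_epochs.

```python
def cyclic_beta(epoch, start, end, n_epochs, n_cycles, ratio):
    """Cyclical linear annealing: ramp start -> end over `ratio` of each cycle, then hold."""
    period = n_epochs / n_cycles
    tau = (epoch % period) / period  # position within the current cycle, in [0, 1)
    if tau >= ratio:
        return end                   # plateau for the rest of the cycle
    return start + (end - start) * tau / ratio


# Values from the YAML above: one cycle every 125 epochs
schedule = [cyclic_beta(e, 0.0, 1.0e-2, 500, 4, 0.5) for e in range(500)]
```

With these values, beta rises from 0 toward 1.0e-2 over roughly the first 62 epochs of each cycle and holds at 1.0e-2 for the rest of the cycle. It restarts at 0 at epochs 125, 250, and 375.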
Save the best model checkpoint
experiments/callback_best_model.yaml saves the checkpoint with the lowest train_loss:
```yaml
# experiments/callback_best_model.yaml
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: train_loss
        filename: "{epoch}-{train_loss:.2f}"
        save_top_k: 1
        mode: min
        every_n_epochs: 1
```
Complete example configurations
VAE with CNN encoder (128×128 images)
```bash
spherinator fit \
  -c experiments/vae_cnn3.yaml \
  -c experiments/illustris_small.yaml \
  -c experiments/wandb.yaml \
  -c experiments/callback_best_model.yaml
```
VAE with Vision Transformer encoder (224×224 images)
```bash
spherinator fit \
  -c experiments/vae_vit.yaml \
  -c experiments/illustris_small.yaml \
  -c experiments/wandb.yaml \
  -c experiments/callback_log_reconstructions.yaml
```
Exporting to ONNX
A trained model can be exported to ONNX for deployment outside of PyTorch.
The export_onnx() function loads a checkpoint and its corresponding CLI config, and writes three
ONNX files to export_path:
| File | Description |
|---|---|
| … | Backbone + sphere head → unit-normalised location vector |
| … | Latent vector → reconstructed output |
| … | End-to-end encode → decode pipeline |
All three exported graphs support a dynamic batch axis, so the exported models accept any batch size at inference time.