Spherinator: Model Training
Spherinator provides representation learning using autoencoders to compress generic data to a low-dimensional latent space. The primary model is a Variational Autoencoder (VAE) with a spherical latent space based on a Power Spherical distribution, which is particularly well-suited for the interactive visualization of high-dimensional data such as images, spectra, point clouds, and cubes.
The encoder compresses the input into a location vector on the unit hypersphere and a concentration
scale. The decoder reconstructs the input from a sample drawn from that distribution. The
VariationalEncoder wraps any backbone encoder with two linear heads:
- fc_location: maps the backbone output to a unit-normalized location vector of dimension z_dim.
- fc_scale: maps the backbone output to a positive concentration scalar (via softplus + 1 to avoid collapse).
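The two output constraints described above can be illustrated with a dependency-free sketch. This is not Spherinator's implementation, which uses linear layers on top of a backbone; the function names and the plain-list vectors are assumptions made for illustration. The sketch shows only that the location has unit norm and that softplus(x) + 1 stays strictly above 1.

```python
import math


def unit_normalize(vec):
    """Scale a vector to unit L2 norm so it lies on the hypersphere."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [v / norm for v in vec]


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def concentration(raw):
    """Map an unconstrained head output to a scale strictly above 1."""
    return softplus(raw) + 1.0


# Hypothetical raw head outputs for z_dim = 3.
location = unit_normalize([3.0, 0.0, 4.0])
scale = concentration(-20.0)  # even a very negative input stays above 1
print(location)
print(scale)
```

Because softplus is positive for every real input, adding 1 gives the concentration a lower bound of 1 regardless of what the backbone produces.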
Installation
Spherinator can be installed via pip:
pip install spherinator
Training the model
Training requires a YAML configuration file that specifies the data, model architecture, and training parameters. Multiple config files can be composed on the command line; later files override earlier ones.
spherinator fit -c config.yaml
Individual arguments can be overridden inline:
spherinator fit -c config.yaml \
--model.init_args.z_dim 16 \
--trainer.devices "[0,1]" \
--trainer.max_epochs 100
Configs can be composed by chaining multiple -c flags:
spherinator fit \
-c experiments/illustris.yaml \
-c experiments/vae_vit.yaml \
-c experiments/wandb.yaml
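The rule that later files override earlier ones can be pictured as a recursive dictionary merge. The sketch below is a simplified mental model, not the CLI's actual code: `merge_configs` is a hypothetical helper, the two dictionaries stand in for parsed YAML files, and the real merging rules (for example for lists such as callbacks) may differ.

```python
from copy import deepcopy


def merge_configs(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in `override`
    replaces the value in `base`.
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Hypothetical contents of two config files, already parsed into dicts.
first = {"trainer": {"max_epochs": 500, "precision": "bf16-mixed"}}
second = {"trainer": {"max_epochs": 100}}

config = merge_configs(first, second)
print(config)  # {'trainer': {'max_epochs': 100, 'precision': 'bf16-mixed'}}
```

Keys that only appear in the earlier file survive, which is why small files such as a logger or callback config can be layered on top of a full experiment config.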
Training can be resumed from a checkpoint:
spherinator fit -c config.yaml --ckpt_path path/to/checkpoint.ckpt
DataModule
The DataModule loads data from Apache Parquet files, applies
optional per-column transforms, and feeds batches to the model. It is defined in the data section
of the YAML file.
data:
  class_path: spherinator.data.DataModule
  init_args:
    path: ./data/illustris_SKIRT_synthetic_images/parquet-128
    columns:
      - name: data
        transform:
          class_path: torchvision.transforms.v2.Resize
          init_args:
            size: 224
    return_dict: False
  dict_kwargs:
    batch_size: 16
    shuffle: True
    num_workers: 4
| Parameter | Description |
|---|---|
| `path` | Path to the directory containing Parquet files |
| `columns` | List of columns to load; each entry may carry a `transform` |
| `return_dict` | If … |
| `batch_size` | Number of samples per mini-batch |
| `shuffle` | Shuffle the dataset each epoch |
| `num_workers` | Number of parallel data-loading workers |
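The `class_path` and `init_args` keys used throughout the configuration name a class and its constructor arguments. As a rough mental model (an assumption about the mechanism, not Spherinator's or Lightning's actual code), such an entry can be resolved with a dynamic import. The `instantiate` helper is hypothetical, and a standard-library class stands in for the data module so the sketch runs without Spherinator installed.

```python
import importlib


def instantiate(spec):
    """Build an object from a {'class_path': ..., 'init_args': ...} mapping."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# fractions.Fraction is only a stand-in for a class such as
# spherinator.data.DataModule.
spec = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3, "denominator": 4},
}
obj = instantiate(spec)
print(obj)  # 3/4
```

Nested entries, such as the `transform` of a column, would be resolved the same way before being passed to the outer constructor.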
Model architecture
The top-level model is defined in the model section of the YAML file.
Two model classes are available:
| Class | Description |
|---|---|
| `Autoencoder` | Deterministic autoencoder |
| `VariationalAutoencoder` | VAE with Power Spherical latent distribution |
Autoencoder
The deterministic autoencoder encodes directly to a fixed-size vector without sampling:
model:
  class_path: spherinator.models.Autoencoder
  init_args:
    encoder:
      class_path: ...
    decoder:
      class_path: ...
VariationalAutoencoder
model:
  class_path: spherinator.models.VariationalAutoencoder
  init_args:
    encoder:
      class_path: ...  # any encoder below
    decoder:
      class_path: ...  # any decoder below
    encoder_out_dim: 64  # must match encoder output_dim
    z_dim: 3  # latent space dimensionality
    beta: 1.0e-3  # KL weight (beta-VAE)
    loss: MSE  # MSE | NLL-normal | NLL-truncated | KL
    fixed_scale: null  # fix concentration (null = learnable)
| Parameter | Description |
|---|---|
| `encoder_out_dim` | Must match the `output_dim` of the encoder |
| `z_dim` | Dimension of the spherical latent space; 3 maps to a sphere (S²) |
| `beta` | Scales the KL divergence term relative to the reconstruction loss |
| `loss` | Reconstruction loss: `MSE`, `NLL-normal`, `NLL-truncated`, or `KL` |
| `fixed_scale` | If set to a float, the concentration is frozen at that value |
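The role of `beta` can be made concrete with the standard beta-VAE objective, in which the total loss is the reconstruction term plus the KL term scaled by `beta`. The sketch below assumes that form and uses made-up numbers; how Spherinator reduces the terms over a batch is not specified here, and both helper functions are hypothetical.

```python
def mse(target, reconstruction):
    """Mean squared error between two equally sized sequences."""
    if len(target) != len(reconstruction):
        raise ValueError("inputs must have the same length")
    return sum((t - r) ** 2 for t, r in zip(target, reconstruction)) / len(target)


def beta_vae_loss(reconstruction_loss, kl_divergence, beta=1.0e-3):
    """Total loss: reconstruction term plus beta-weighted KL term."""
    return reconstruction_loss + beta * kl_divergence


# Hypothetical values; the KL value would come from the latent distribution.
recon = mse([0.0, 1.0, 0.5], [0.1, 0.8, 0.5])
total = beta_vae_loss(recon, kl_divergence=2.0, beta=1.0e-3)
print(recon, total)
```

With a small `beta` such as 1.0e-3, the KL term contributes little, so training is dominated by reconstruction quality; larger values pull the latent distribution more strongly toward the prior.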
Encoder architectures
The encoder architecture should be chosen to match the input data type.
ConvolutionalEncoder2D
Standard 2D CNN encoder built from ConsecutiveConv2DLayer blocks. Each block is a sequence of
LazyConv2d layers with optional batch normalization, activation, and pooling. The output is
flattened and projected to output_dim via a lazy linear layer.
encoder:
  class_path: spherinator.models.ConvolutionalEncoder2D
  init_args:
    input_dim: [3, 128, 128]
    output_dim: 64
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [16, 20, 24]
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 4
          stride: 2
          padding: 0
          out_channels: [32, 64]
Each entry in out_channels adds one convolutional layer. The ConsecutiveConv2DLayer arguments
are:
| Argument | Description |
|---|---|
| `kernel_size` | Kernel size for all layers in this block |
| `stride` | Stride for all layers in this block |
| `padding` | Padding for all layers in this block |
| `out_channels` | List of output channel counts; one layer per entry |
| `activation` | Activation function class (default: …) |
| … | Normalization class (default: …) |
| … | Optional pooling module appended after each layer |
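To see what spatial size reaches the final linear projection, the standard convolution output-size formula can be applied layer by layer. The sketch below is a hand calculation under stated assumptions (dilation 1, no pooling module, square inputs); the helper names are hypothetical and not part of Spherinator.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output length of one convolution along a spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def block_output_size(size, kernel_size, stride, padding, num_layers):
    """Apply the same convolution arithmetic once per layer in a block."""
    for _ in range(num_layers):
        size = conv_output_size(size, kernel_size, stride, padding)
    return size


# Blocks from the ConvolutionalEncoder2D example above, assuming no pooling.
size = 128
size = block_output_size(size, kernel_size=3, stride=1, padding=0, num_layers=3)
print(size)  # 122
size = block_output_size(size, kernel_size=4, stride=2, padding=0, num_layers=2)
print(size)  # 29
```

Under these assumptions the last feature map is 64 × 29 × 29. Because the projection to `output_dim` is a lazy linear layer, this flattened size does not have to be written into the configuration.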
ConvolutionalEncoder1D
Analogous to ConvolutionalEncoder2D but for 1D inputs such as spectra, using
ConsecutiveConv1DLayer with LazyConv1d layers.
HuggingFaceViTEncoder
Wraps any HuggingFace Vision Transformer. The CLS token from the last hidden state is optionally
projected to output_dim via a linear layer.
encoder:
  class_path: spherinator.models.HuggingFaceViTEncoder
  init_args:
    model_name: google/vit-base-patch16-224
    output_dim: 64
    freeze: False
| Argument | Description |
|---|---|
| `model_name` | HuggingFace model identifier; must be a ViT variant |
| `output_dim` | Output projection size; if … |
| `freeze` | If … |
Decoder architectures
ConvolutionalDecoder2D
Transposed-convolution decoder. A linear layer projects the latent vector and reshapes it into the
seed spatial tensor; ConsecutiveConvTranspose2DLayer blocks then upsample to the target resolution.
decoder:
  class_path: spherinator.models.ConvolutionalDecoder2D
  init_args:
    input_dim: 3
    output_dim: [3, 128, 128]
    cnn_input_dim: [64, 28, 28]
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 5
          stride: 2
          padding: 0
          out_channels: [32]
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [20, 16, 3]
          activation: null
cnn_input_dim sets the shape [C, H, W] that the seed linear projection reshapes to. The overall
spatial path must reach output_dim[1:] through the stacked transpose-conv blocks.
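A quick way to check that a stack of transposed-convolution blocks reaches the target resolution is to trace the spatial size by hand. The sketch below uses the usual transposed-convolution size formula and assumes `output_padding=0` and dilation 1; the helper names are hypothetical, and any extra resizing the decoder might apply is not modeled.

```python
def conv_transpose_output_size(size, kernel_size, stride=1, padding=0,
                               output_padding=0):
    """Output length of one transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


def trace_sizes(seed, blocks):
    """Return the spatial size after every layer of every block."""
    sizes = [seed]
    for kernel_size, stride, padding, num_layers in blocks:
        for _ in range(num_layers):
            sizes.append(conv_transpose_output_size(
                sizes[-1], kernel_size, stride, padding))
    return sizes


# (kernel_size, stride, padding, number of layers) for the two blocks of the
# ConvolutionalDecoder2D example above, starting from the 28x28 seed.
blocks = [(5, 2, 0, 1), (3, 1, 0, 3)]
print(trace_sizes(28, blocks))  # [28, 59, 61, 63, 65]
```

Under these assumptions the listed blocks end at 65×65 rather than 128×128, so it is worth running a trace like this against output_dim[1:] when adapting the example to a new resolution.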
UpsamplingDecoder2D
Bilinear-upsampling decoder that avoids the checkerboard artifacts of transposed convolutions. Each
_UpsampleBlock doubles the spatial resolution with a bilinear upsample followed by a 3×3
convolution, batch normalization, and ReLU. A final 1×1 convolution maps to the output channels,
followed by a sigmoid.
z → Linear → reshape (base_channels, seed_size, seed_size)
→ n × UpsampleBlock (2×)
→ 1×1 Conv → Sigmoid → output
The number of upsampling steps is inferred automatically to reach output_dim[1] from seed_size.
decoder:
  class_path: spherinator.models.UpsamplingDecoder2D
  init_args:
    input_dim: 64
    output_dim: [3, 224, 224]
    base_channels: 512
    seed_size: 7
| Argument | Description |
|---|---|
| `input_dim` | Latent vector size |
| `output_dim` | Target image shape |
| `base_channels` | Channel count at the spatial seed; halved at each upsampling step |
| `seed_size` | Spatial width/height of the seed feature map |
The default seed_size: 7 with five upsampling steps reaches 224×224
($7 \times 2^5 = 224$), which matches the 224×224 input resolution of google/vit-base-patch16-224.
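The inference of the number of upsampling steps can be sketched as a base-2 logarithm of the ratio between target and seed size. The `upsampling_plan` helper below is hypothetical and only mirrors the description given here (each step doubles the resolution and halves the channel count); the library's own handling of edge cases may differ.

```python
import math


def upsampling_plan(seed_size, target_size, base_channels):
    """Return (number of 2x steps, channel count at the seed and after each step)."""
    steps = round(math.log2(target_size / seed_size))
    if steps < 0 or seed_size * 2 ** steps != target_size:
        raise ValueError("target_size must be seed_size times a power of two")
    channels = [base_channels // 2 ** i for i in range(steps + 1)]
    return steps, channels


steps, channels = upsampling_plan(seed_size=7, target_size=224, base_channels=512)
print(steps)     # 5
print(channels)  # [512, 256, 128, 64, 32, 16]
```

A target size that is not the seed size times a power of two, such as 200 with a seed of 7, cannot be reached by doubling alone, which is why the sketch rejects it.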
ConvolutionalDecoder1D
Mirror of ConvolutionalDecoder2D for 1D outputs using ConsecutiveConvTranspose1DLayer.
Optimizer
Any PyTorch optimizer can be specified in the optimizer section:
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
Trainer
The trainer section maps directly to
Lightning’s Trainer:
trainer:
  max_epochs: 500
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  enable_progress_bar: True
  enable_model_summary: True
Common precision options: 32, 16-mixed, bf16-mixed.
Weights & Biases integration
Append -c experiments/wandb.yaml to the command line to enable W&B logging. Edit entity and
tags as needed:
# experiments/wandb.yaml
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: spherinator
      log_model: True
      entity: <your-wandb-entity>
      tags:
        - my_experiment
Callbacks
Callbacks are appended as additional YAML files or inline under trainer.callbacks.
Log reconstructions during training
experiments/callback_log_reconstructions.yaml logs a fixed set of sample reconstructions to W&B
after every validation epoch. Requires W&B to be configured.
# experiments/callback_log_reconstructions.yaml
trainer:
  callbacks:
    - class_path: spherinator.callbacks.LogReconstructionCallback
      init_args:
        samples: 6
Save the best model checkpoint
experiments/callback_best_model.yaml saves the checkpoint with the lowest train_loss:
# experiments/callback_best_model.yaml
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: train_loss
        filename: "{epoch}-{train_loss:.2f}"
        save_top_k: 1
        mode: min
        every_n_epochs: 1
Complete example configurations
VAE with CNN encoder (128×128 images)
spherinator fit \
-c experiments/vae_cnn3.yaml \
-c experiments/illustris_small.yaml \
-c experiments/wandb.yaml \
-c experiments/callback_best_model.yaml
VAE with Vision Transformer encoder (224×224 images)
spherinator fit \
-c experiments/vae_vit.yaml \
-c experiments/illustris_small.yaml \
-c experiments/wandb.yaml \
-c experiments/callback_log_reconstructions.yaml