# Spherinator: Model Training

Spherinator provides representation learning using autoencoders to compress generic data into a low-dimensional latent space. The primary model is a Variational Autoencoder (VAE) with a **spherical latent space** based on a [Power Spherical](https://github.com/nicola-decao/power_spherical) distribution, which is particularly well-suited for the interactive visualization of high-dimensional data such as images, spectra, point clouds, and cubes.

```{figure} assets/vae.svg
---
name: fig:vae
align: center
---
```

The encoder compresses the input into a location vector on the unit hypersphere and a concentration scale. The decoder reconstructs the input from a sample drawn from the resulting Power Spherical distribution.

The `VariationalEncoder` wraps any backbone encoder with two linear heads:

- **`fc_location`**: maps the backbone output to a unit-normalized location vector of dimension `z_dim`.
- **`fc_scale`**: maps the backbone output to a positive concentration scalar (via `softplus + 1`, which keeps it above 1 to avoid collapse).

## Installation

Spherinator can be installed via `pip`:

```bash
pip install spherinator
```

## Training the model

Training requires a YAML configuration file that specifies the data, the model architecture, and the training parameters. Multiple config files can be composed on the command line; later files override earlier ones.
```bash
spherinator fit -c config.yaml
```

Individual arguments can be overridden inline:

```bash
# Quote list values so the shell does not treat [0,1] as a glob pattern
spherinator fit -c config.yaml \
    --model.init_args.z_dim 16 \
    --trainer.devices "[0,1]" \
    --trainer.max_epochs 100
```

Configs are **composed** by chaining multiple `-c` flags:

```bash
spherinator fit \
    -c experiments/illustris.yaml \
    -c experiments/vae_vit.yaml \
    -c experiments/wandb.yaml
```

Training can be resumed from a checkpoint:

```bash
spherinator fit -c config.yaml --ckpt_path path/to/checkpoint.ckpt
```

## DataModule

The `DataModule` loads data from [Apache Parquet](https://parquet.apache.org/) files, applies optional per-column transforms, and feeds batches to the model. It is defined in the `data` section of the YAML file.

```yaml
data:
  class_path: spherinator.data.DataModule
  init_args:
    path: ./data/illustris_SKIRT_synthetic_images/parquet-128
    columns:
      - name: data
        transform:
          class_path: torchvision.transforms.v2.Resize
          init_args:
            size: 224
    return_dict: False
    batch_size: 16
    shuffle: True
    num_workers: 4
```

| Parameter | Description |
|-----------|-------------|
| `path` | Path to the directory containing the Parquet files |
| `columns` | List of columns to load; each entry has a `name` and may carry a `transform` |
| `return_dict` | If `True`, batches are dicts keyed by column name; if `False`, the tensor of the first column is returned directly |
| `batch_size` | Number of samples per mini-batch |
| `shuffle` | Shuffle the dataset each epoch |
| `num_workers` | Number of parallel data-loading workers |

## Model architecture

The top-level model is defined in the `model` section of the YAML file.
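Every `class_path`/`init_args` pair in these sections is turned into an object by LightningCLI, which delegates to jsonargparse. The class is imported, nested pairs such as `encoder` and `decoder` are resolved first, and the class is then called with `init_args` as keyword arguments. The sketch below is a simplified, hypothetical re-implementation of that idea. `instantiate` is not part of Spherinator or jsonargparse, and it leaves out the type checking and shorthand forms that jsonargparse supports, such as the bare `torch.nn.MSELoss` string used below.

```python
import importlib
from typing import Any


def instantiate(spec: Any) -> Any:
    """Recursively build objects from {"class_path": ..., "init_args": {...}} specs."""
    if isinstance(spec, dict) and "class_path" in spec:
        module_name, _, class_name = spec["class_path"].rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
        # Resolve nested specs (e.g. encoder/decoder) before calling the class.
        kwargs = {k: instantiate(v) for k, v in (spec.get("init_args") or {}).items()}
        return cls(**kwargs)
    if isinstance(spec, dict):
        return {k: instantiate(v) for k, v in spec.items()}
    if isinstance(spec, list):
        return [instantiate(v) for v in spec]
    return spec  # plain values (numbers, strings, None) pass through unchanged


# Standard-library stand-ins play the roles of the model and its encoder.
config = {
    "class_path": "types.SimpleNamespace",
    "init_args": {
        "encoder": {"class_path": "datetime.timedelta", "init_args": {"hours": 2}},
        "z_dim": 3,
    },
}
model = instantiate(config)
print(model.encoder, model.z_dim)  # 2:00:00 3
```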
Two model classes are available:

| Class | Description |
|-------|-------------|
| `spherinator.models.Autoencoder` | Deterministic autoencoder |
| `spherinator.models.VariationalAutoencoder` | VAE with Power Spherical latent distribution |

### Autoencoder

The deterministic autoencoder encodes each input directly to a fixed-size latent vector, without sampling:

```yaml
model:
  class_path: spherinator.models.Autoencoder
  init_args:
    encoder:
      class_path: ...
    decoder:
      class_path: ...
    reconstruction_loss: torch.nn.MSELoss
```

### VariationalAutoencoder

```yaml
model:
  class_path: spherinator.models.VariationalAutoencoder
  init_args:
    encoder:
      class_path: ...        # any encoder below
    decoder:
      class_path: ...        # any decoder below
    encoder_out_dim: 64      # must match encoder output_dim
    z_dim: 3                 # latent space dimensionality
    beta: 1.0e-3             # KL weight (beta-VAE)
    reconstruction_loss: torch.nn.MSELoss
```

| Parameter | Description |
|-----------|-------------|
| `encoder_out_dim` | Must match the `output_dim` of the encoder backbone |
| `z_dim` | Dimension of the latent vectors, which lie on the unit hypersphere in that many dimensions; `3` gives the ordinary sphere S² |
| `beta` | Weight of the KL divergence term relative to the reconstruction loss |
| `reconstruction_loss` | Loss module comparing the reconstruction with the input (see Loss functions below) |
| `max_scale` | Maximum concentration scale, to prevent collapse |

## Encoder architectures

Choose the encoder architecture to match the input data type.

### ConvolutionalEncoder2D

Standard 2D CNN encoder built from `ConsecutiveConv2DLayer` blocks. Each block is a sequence of `LazyConv2d` layers with optional batch normalization, activation, and pooling. The output is flattened and projected to `output_dim` by a lazy linear layer.
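The following is a minimal PyTorch sketch of such an encoder. It assumes that each layer in a block is applied as convolution, then normalization, then activation, then optional pooling; the library's exact ordering may differ. `conv_block` and `conv_encoder` are hypothetical names, not Spherinator's classes. Their arguments mirror the example configuration that follows.

```python
import copy

import torch
from torch import nn


def conv_block(out_channels, kernel_size=3, stride=1, padding=0,
               activation=nn.ReLU, norm=nn.BatchNorm2d, pooling=None):
    """One block: a LazyConv2d (+ norm, activation, pooling) per out_channels entry."""
    layers = []
    for channels in out_channels:
        # Lazy layer: in_channels is inferred from the first input tensor.
        layers.append(nn.LazyConv2d(channels, kernel_size, stride=stride, padding=padding))
        if norm is not None:
            layers.append(norm(channels))
        if activation is not None:
            layers.append(activation())
        if pooling is not None:
            layers.append(copy.deepcopy(pooling))  # a separate module after each layer
    return nn.Sequential(*layers)


def conv_encoder(blocks, output_dim):
    """Stack the blocks, flatten, and project to output_dim with a lazy linear layer."""
    return nn.Sequential(*blocks, nn.Flatten(), nn.LazyLinear(output_dim))


encoder = conv_encoder(
    [conv_block([16, 20, 24], kernel_size=3, stride=1),
     conv_block([32, 64], kernel_size=4, stride=2)],
    output_dim=64,
)
z = encoder(torch.zeros(2, 3, 128, 128))  # the first call materializes the lazy layers
print(z.shape)  # torch.Size([2, 64])
```

Because the layers are lazy, only `input_dim` and the block settings determine the flattened size (here 64 × 29 × 29). The linear projection adapts to it automatically.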
```yaml
encoder:
  class_path: spherinator.models.ConvolutionalEncoder2D
  init_args:
    input_dim: [3, 128, 128]
    output_dim: 64
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [16, 20, 24]
      - class_path: spherinator.models.ConsecutiveConv2DLayer
        init_args:
          kernel_size: 4
          stride: 2
          padding: 0
          out_channels: [32, 64]
```

Each entry in `out_channels` adds one convolutional layer, so this example has five. The `ConsecutiveConv2DLayer` arguments are:

| Argument | Description |
|----------|-------------|
| `kernel_size` | Kernel size for all layers in this block |
| `stride` | Stride for all layers in this block |
| `padding` | Padding for all layers in this block |
| `out_channels` | List of output channel counts; one layer per entry |
| `activation` | Activation function class (default: `nn.ReLU`) |
| `norm` | Normalization class (default: `nn.BatchNorm2d`) |
| `pooling` | Optional pooling module appended after each layer |

### ConvolutionalEncoder1D

Analogous to `ConvolutionalEncoder2D`, but for 1D inputs such as spectra, using `ConsecutiveConv1DLayer` blocks of `LazyConv1d` layers.

### HuggingFaceViTEncoder

Wraps any HuggingFace Vision Transformer. The CLS token of the last hidden state is optionally projected to `output_dim` by a linear layer.

```yaml
encoder:
  class_path: spherinator.models.HuggingFaceViTEncoder
  init_args:
    model_name: google/vit-base-patch16-224
    output_dim: 64
    freeze: False
```

| Argument | Description |
|----------|-------------|
| `model_name` | HuggingFace model identifier; must be a ViT variant |
| `output_dim` | Output projection size; if `null`, the model's hidden size is used |
| `freeze` | If `True`, the ViT backbone weights are frozen |

## Decoder architectures

### ConvolutionalDecoder2D

Transposed-convolution decoder. A linear layer projects the latent vector, and the result is reshaped into a seed feature map. `ConsecutiveConvTranspose2DLayer` blocks then upsample it to the target resolution.
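Because the stacked blocks must carry the seed size all the way to the target size, it helps to trace the spatial size through them. The sketch below applies PyTorch's documented `ConvTranspose2d` output-size formula, $H_\text{out} = (H_\text{in} - 1)\cdot s - 2p + d\,(k - 1) + p_\text{out} + 1$. It assumes that every entry of `out_channels` is one transpose-convolution layer with the block's `kernel_size`, `stride`, and `padding`, with dilation 1 and no output padding. `trace_transpose_blocks` is a hypothetical helper, not part of Spherinator.

```python
def convtranspose_out(size, kernel_size, stride=1, padding=0,
                      output_padding=0, dilation=1):
    """Output height/width of one ConvTranspose2d layer (PyTorch's formula)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


def trace_transpose_blocks(size, blocks):
    """blocks: list of (kernel_size, stride, padding, n_layers) tuples.

    Returns the spatial size after every layer, in order.
    """
    sizes = []
    for kernel_size, stride, padding, n_layers in blocks:
        for _ in range(n_layers):
            size = convtranspose_out(size, kernel_size, stride, padding)
            sizes.append(size)
    return sizes


# Kernel 4, stride 2, padding 1 exactly doubles the size at every layer.
print(trace_transpose_blocks(7, [(4, 2, 1, 4)]))  # [14, 28, 56, 112]
```

Under the same assumptions, the configuration below traces from its 28×28 seed to `[59, 61, 63, 65]`: one kernel-5, stride-2 layer followed by three kernel-3, stride-1 layers. Running such a check against `output_dim` for any decoder stack is cheap to do before training.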
```yaml
decoder:
  class_path: spherinator.models.ConvolutionalDecoder2D
  init_args:
    input_dim: 3
    output_dim: [3, 128, 128]
    cnn_input_dim: [64, 28, 28]
    cnn_layers:
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 5
          stride: 2
          padding: 0
          out_channels: [32]
      - class_path: spherinator.models.ConsecutiveConvTranspose2DLayer
        init_args:
          kernel_size: 3
          stride: 1
          padding: 0
          out_channels: [20, 16, 3]
          activation: null
```

`cnn_input_dim` is the shape `[C, H, W]` into which the output of the seed linear projection is reshaped. The stacked transpose-convolution blocks must then take the spatial size from `cnn_input_dim[1:]` to `output_dim[1:]`.

### UpsamplingDecoder2D

Bilinear-upsampling decoder that avoids the checkerboard artifacts of transposed convolutions. Each `_UpsampleBlock` doubles the spatial resolution with a bilinear upsample followed by a 3×3 convolution, batch normalization, and ReLU. A final 1×1 convolution maps to the output channels, followed by a sigmoid.

```
z → Linear → reshape (base_channels, seed_size, seed_size)
  → n × UpsampleBlock (2×)
  → 1×1 Conv → Sigmoid → output
```

The number of upsampling steps `n` is inferred automatically. Each step doubles the resolution, so `n` is the number of doublings needed to go from `seed_size` to `output_dim[1]`.

```yaml
decoder:
  class_path: spherinator.models.UpsamplingDecoder2D
  init_args:
    input_dim: 64
    output_dim: [3, 224, 224]
    base_channels: 512
    seed_size: 7
```

| Argument | Description |
|----------|-------------|
| `input_dim` | Latent vector size |
| `output_dim` | Target image shape `[C, H, W]` |
| `base_channels` | Channel count at the spatial seed; halved at each upsampling step |
| `seed_size` | Spatial width/height of the seed feature map |

With the default `seed_size: 7`, five upsampling steps reach 224×224 ($7 \times 2^5 = 224$). That is the input resolution of ViT-Base (`google/vit-base-patch16-224`).

### ConvolutionalDecoder1D

Mirror of `ConvolutionalDecoder2D` for 1D outputs, using `ConsecutiveConvTranspose1DLayer` blocks.

## Loss functions

The `reconstruction_loss` field of any model accepts any `nn.Module` loss.
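The following minimal sketch assumes that the model calls the loss as `loss(reconstruction, target)` and expects a scalar back, as the built-in `torch.nn` losses do with their default reduction. Under that assumption, a custom loss only needs a `forward` method. The Charbonnier loss used here is a swapped-in example (a smooth variant of the L1 loss), and both `CharbonnierLoss` and the module path `my_losses` are hypothetical.

```python
import torch
from torch import nn


class CharbonnierLoss(nn.Module):
    """Mean of sqrt((x_hat - x)^2 + eps^2): smooth near zero, L1-like for large errors."""

    def __init__(self, eps: float = 1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, reconstruction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        diff = reconstruction - target
        return torch.sqrt(diff * diff + self.eps**2).mean()


loss_fn = CharbonnierLoss()
x = torch.rand(4, 3, 16, 16)
loss = loss_fn(x, torch.zeros_like(x))  # a scalar tensor, like nn.MSELoss() returns
```

If this class lived in a module `my_losses` on the Python path, it could be selected with `reconstruction_loss: {class_path: my_losses.CharbonnierLoss, init_args: {eps: 1.0e-3}}`.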
Spherinator ships two purpose-built losses in addition to the standard PyTorch losses.

### PerceptualLoss

`spherinator.losses.PerceptualLoss` computes a **VGG-16 feature-matching loss** for sharper, perceptually more faithful image reconstructions. A frozen VGG-16 backbone serves as a fixed feature extractor. The loss is the mean squared error between the activations of the chosen intermediate layers for the reconstruction and for the target. Single-channel (grayscale) inputs are automatically broadcast to three channels before being passed through the network.

```yaml
reconstruction_loss:
  class_path: spherinator.losses.PerceptualLoss
  init_args:
    layers: [3, 8, 15]          # VGG-16 layer indices to tap
    weights: [1.0, 1.0, 1.0]    # per-layer loss weights
```

| Argument | Description |
|----------|-------------|
| `layers` | Indices into `vgg16.features` at which activations are extracted (default: `[3, 8, 15]`) |
| `weights` | Scalar weight for each tapped layer's MSE contribution (default: all `1.0`) |

The default indices 3, 8, and 15 in `vgg16.features` are `relu1_2`, `relu2_2`, and `relu3_3`, the last ReLU of each of the first three convolutional blocks of VGG-16. They therefore capture progressively higher-level features.
### CombinedLoss

`spherinator.losses.CombinedLoss` forms a **weighted sum of multiple loss functions**, making it easy to blend pixel-level and perceptual objectives:

$$
\mathcal{L} = \sum_i w_i \cdot \mathcal{L}_i(\hat{x},\, x)
$$

Here $\mathcal{L}_i$ is the $i$-th entry of `losses`, $w_i$ is the corresponding entry of `factors`, $\hat{x}$ is the reconstruction, and $x$ is the input.

```yaml
reconstruction_loss:
  class_path: spherinator.losses.CombinedLoss
  init_args:
    losses:
      - class_path: torch.nn.MSELoss
      - class_path: spherinator.losses.PerceptualLoss
        init_args:
          layers: [3, 8, 15]
          weights: [1.0, 1.0, 1.0]
    factors: [1.0, 0.1]
```

| Argument | Description |
|----------|-------------|
| `losses` | List of `nn.Module` loss instances |
| `factors` | Scalar weight for each corresponding loss (must be the same length as `losses`) |

## Optimizer

Any PyTorch optimizer can be specified in the `optimizer` section:

```yaml
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.001
```

## Trainer

The `trainer` section maps directly to the arguments of [Lightning's Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html):

```yaml
trainer:
  max_epochs: 500
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  enable_progress_bar: True
  enable_model_summary: True
```

Common precision options are `32`, `16-mixed`, and `bf16-mixed`. The last one requires hardware with bfloat16 support.

## Weights & Biases integration

Append `experiments/wandb.yaml` to the command line to enable [W&B](https://wandb.ai) logging. Edit `entity` and `tags` as needed:

```yaml
# experiments/wandb.yaml
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: spherinator
      log_model: True
      entity:
      tags:
        - my_experiment
```

## Callbacks

Callbacks are added through additional YAML files or inline under `trainer.callbacks`.

### Log reconstructions during training

`experiments/callback_log_reconstructions.yaml` logs a fixed set of sample reconstructions to W&B after every validation epoch. It requires the W&B logger to be configured.
```yaml
# experiments/callback_log_reconstructions.yaml
trainer:
  callbacks:
    - class_path: spherinator.callbacks.LogReconstructionCallback
      init_args:
        samples: 6
```

### KL annealing

`spherinator.callbacks.KLAnnealing` gradually ramps up the KL-divergence weight `beta` of a `VariationalAutoencoder` during training. Starting with a small (or zero) KL weight helps prevent posterior collapse in the early epochs. The encoder first learns good reconstructions, and only then does the regularisation pressure increase.

The schedule supports **cyclic annealing**: the ramp from `start` to `end` is repeated `n_cycles` times over `n_epochs`. This technique, known as cyclical annealing, has been reported to improve latent-space utilisation. The `ratio` parameter sets the fraction of each cycle spent ramping; for the remainder of the cycle, `beta` stays at `end`.

```yaml
trainer:
  callbacks:
    - class_path: spherinator.callbacks.KLAnnealing
      init_args:
        start: 0.0
        end: 1.0e-2
        n_epochs: 500
        n_cycles: 4
        ratio: 0.5
```

| Argument | Description |
|----------|-------------|
| `start` | Initial value of `beta` at the beginning of each cycle (default: `0.0`) |
| `end` | Target value of `beta` at the end of the ramp (default: `1.0e-2`) |
| `n_epochs` | Total number of training epochs over which the cycles are spread; usually equal to `trainer.max_epochs` (default: `100`) |
| `n_cycles` | Number of times the ramp is repeated (default: `1`, i.e. a monotone schedule) |
| `ratio` | Fraction of each cycle spent linearly ramping from `start` to `end` (default: `1.0`) |

The current `beta` value is logged under the key `"beta"` every epoch, so it can be tracked in W&B or TensorBoard.
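The schedule described above can be written as a function of the epoch. This sketch of cyclical linear annealing assumes that the cycle period is `n_epochs / n_cycles` and that `beta` is updated once per epoch. `kl_beta` is a hypothetical helper, and the callback's exact rounding may differ.

```python
def kl_beta(epoch, start=0.0, end=1.0e-2, n_epochs=100, n_cycles=1, ratio=1.0):
    """Cyclical linear KL annealing: ramp from start to end, then hold at end."""
    period = n_epochs / n_cycles            # epochs per cycle
    position = (epoch % period) / period    # progress within the current cycle, in [0, 1)
    if ratio <= 0 or position >= ratio:     # ramp finished: hold at `end`
        return end
    return start + (end - start) * position / ratio


# Settings from the YAML above: 4 cycles of 125 epochs, each ramping during its first half.
for epoch in (0, 31, 62, 63, 124, 125):
    print(epoch, kl_beta(epoch, n_epochs=500, n_cycles=4, ratio=0.5))
```

With these settings, `beta` climbs to `end` over the first 62.5 epochs of each cycle, holds there until epoch 125, and then drops back to `start` for the next cycle.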
### Save the best model checkpoint

`experiments/callback_best_model.yaml` saves the checkpoint with the lowest `train_loss`:

```yaml
# experiments/callback_best_model.yaml
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: train_loss
        filename: "{epoch}-{train_loss:.2f}"
        save_top_k: 1
        mode: min
        every_n_epochs: 1
```

## Complete example configurations

### VAE with CNN encoder (128×128 images)

```bash
spherinator fit \
    -c experiments/vae_cnn3.yaml \
    -c experiments/illustris_small.yaml \
    -c experiments/wandb.yaml \
    -c experiments/callback_best_model.yaml
```

### VAE with Vision Transformer encoder (224×224 images)

```bash
spherinator fit \
    -c experiments/vae_vit.yaml \
    -c experiments/illustris_small.yaml \
    -c experiments/wandb.yaml \
    -c experiments/callback_log_reconstructions.yaml
```

## Exporting to ONNX

A trained model can be exported to [ONNX](https://onnx.ai/) for deployment outside of PyTorch. The {py:func}`~spherinator.models.export_onnx` function loads a checkpoint together with its corresponding CLI config and writes three ONNX files to `export_path`:

| File | Description |
|------|-------------|
| `encoder.onnx` | Backbone + sphere head → unit-normalised location vector |
| `decoder.onnx` | Latent vector → reconstructed output |
| `reconstruction.onnx` | End-to-end encode → decode pipeline |

All three exported graphs have a dynamic batch axis, so the exported models accept any batch size at inference time.
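This is a minimal sketch of consuming the exported encoder with ONNX Runtime. It assumes that the graph takes a single `float32` batch as input and returns the location vectors as its first output; check `session.get_inputs()` and `session.get_outputs()` to confirm. For `z_dim: 3` the vectors lie on S², so the helper `to_lon_lat` converts them to longitude and latitude in degrees, a common step before plotting. Both `encode` and `to_lon_lat` are hypothetical names, not part of Spherinator.

```python
import numpy as np


def encode(encoder_path, batch):
    """Run encoder.onnx on a batch and return the location vectors (first output)."""
    import onnxruntime as ort  # imported lazily; only needed for inference

    session = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: np.asarray(batch, dtype=np.float32)})
    return outputs[0]


def to_lon_lat(locations):
    """Map unit vectors of shape (N, 3) to (longitude, latitude) in degrees."""
    x, y, z = np.asarray(locations, dtype=np.float64).T
    lon = np.degrees(np.arctan2(y, x))
    lat = np.degrees(np.arcsin(np.clip(z, -1.0, 1.0)))  # clip guards against rounding
    return np.stack([lon, lat], axis=1)


print(to_lon_lat([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
```

For example, `to_lon_lat(encode(path_to_encoder_onnx, images))` would yield one (longitude, latitude) pair per image, with `path_to_encoder_onnx` pointing at `encoder.onnx` inside `export_path`. Because the batch axis is dynamic, `images` can have any batch size.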