
Context-Aware Flow Matching

Authors: Lars Kühmichel

This site is a work in progress. Please check back later.

GitHub Repository

GitHub Pages Project Page

This repository contains parts of the code used in my master’s thesis, titled Advancements in Context-Aware Learning and Generative Modeling.

1. Introduction

In the thesis, we define our approach to context-aware learning as deep learning that uses an embedding computed from a set of context inputs:

Context-Aware Learning
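To make the idea of "an embedding from a set of context inputs" concrete, here is a minimal sketch of a permutation-invariant set embedding. It assumes mean pooling over per-element features (in the spirit of the permutation-invariant networks named in the related work cited below); the actual architecture in the thesis may differ. The names `embed_element`, `embed_context`, the random weights, and all sizes are hypothetical, and the code is plain Python so it runs without the repository's dependencies.

```python
import math
import random

def embed_element(x, weights):
    """Per-element feature map: one fixed tanh layer standing in for a learned network."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in weights]

def embed_context(context_set, weights):
    """Mean-pool the per-element features, so the result ignores element order."""
    features = [embed_element(x, weights) for x in context_set]
    n = len(features)
    return [sum(f[k] for f in features) / n for k in range(len(weights))]

rng = random.Random(0)
# Hypothetical sizes: 5 context inputs in 3 dimensions, embedded into 4 dimensions.
weights = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]
context = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(5)]

embedding = embed_context(context, weights)

# Shuffling the context set leaves the embedding unchanged (up to rounding).
shuffled = context[:]
rng.shuffle(shuffled)
embedding_shuffled = embed_context(shuffled, weights)
```

Mean pooling is only one possible aggregation; sum or max pooling would give the same order invariance.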

In this repo, we use Optimal Transport Flow Matching to leverage this embedding and learn a generative model that can be conditioned on sampled context embeddings, thus enabling interpolation between contexts:

<img src="docs/context-aware-flow-matching.webp" width=50% alt="Interpolation">
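As a rough illustration of the training objective, the sketch below builds one flow matching training pair and evaluates a context-conditioned regression loss. It assumes the straight-line (optimal transport displacement) path between a noise sample `x0` and a data sample `x1`; whether the thesis additionally couples samples with minibatch optimal transport is not stated here. The function names, the `velocity_fn(xt, t, context)` signature, and the toy inputs are hypothetical and not taken from the repository.

```python
import random

def flow_matching_pair(x0, x1, t):
    """Point on the straight path from x0 to x1 at time t, and its constant velocity."""
    xt = [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]
    target = [b - a for a, b in zip(x0, x1)]
    return xt, target

def flow_matching_loss(velocity_fn, x0, x1, t, context):
    """Mean squared error between the predicted and the target velocity."""
    xt, target = flow_matching_pair(x0, x1, t)
    predicted = velocity_fn(xt, t, context)
    return sum((p - u) ** 2 for p, u in zip(predicted, target)) / len(target)

def zero_velocity(xt, t, context):
    """Untrained stand-in for a network: always predicts zero velocity."""
    return [0.0 for _ in xt]

rng = random.Random(0)
x0 = [rng.gauss(0.0, 1.0) for _ in range(3)]  # noise sample
x1 = [0.5, -1.0, 2.0]                         # data sample (hypothetical)
context = [0.1, 0.2]                          # context embedding (hypothetical)
t = rng.random()                              # time drawn uniformly from [0, 1)

loss = flow_matching_loss(zero_velocity, x0, x1, t, context)
```

In training, `velocity_fn` would be a neural network that receives the context embedding as an extra input, and the loss would be averaged over a batch of `(x0, x1, t, context)` tuples.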

2. Install

Create a new conda environment with the required dependencies:

conda env create -f env.yaml

Activate the environment:

conda activate context-aware-flow-matching

Verify your install by running pytest:

pytest tests -m "not slow"

If you want to plot samples using Blender, install the Blender environment instead:

conda env create -f blender.yaml

Activate and verify as above.

Note that the two environments are incompatible with each other because they require different Python versions.

Experiment notebooks can be found in the experiments folder. We use Lightning-Trainable to train our models. Each notebook contains the hyperparameters used for training.

3. ModelNet10

Dataset: ModelNet10

ModelNet10 Dataset Samples and Model Reconstructions
Left: Samples from the dataset. Right: Model reconstructions.

ModelNet10 Model Samples

Random samples from the trained model.
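Drawing a sample from a flow matching model amounts to integrating the learned velocity field from the latent space to the data space. The sketch below assumes a fixed-step Euler solver over t in [0, 1]; the repository may use a different solver or step count. `euler_sample` and `toy_velocity` are hypothetical names, and the toy field is a stand-in for a trained network.

```python
def euler_sample(velocity_fn, x0, context, steps=100):
    """Integrate dx/dt = v(x, t, context) from t = 0 to t = 1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / steps
    for i in range(steps):
        t = i * dt
        v = velocity_fn(x, t, context)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x

def toy_velocity(x, t, context):
    """Stand-in for a trained network: a constant drift given by the context."""
    return list(context)

latent = [0.0, 1.0, -1.0]    # point in the latent space (hypothetical)
context = [2.0, -0.5, 0.25]  # context embedding (hypothetical)

sample = euler_sample(toy_velocity, latent, context)
```

With this constant toy field, the result is simply `latent` shifted by `context`; a trained network would instead bend the trajectory toward the shape selected by the context.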

ModelNet10 Context Interpolation

Linear interpolation between randomly sampled contextual embeddings.
Transition between the data and latent space.
Rotating points in the latent space of the flow helps visualize the shape manifold.
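The linear interpolation between context embeddings mentioned above can be sketched as follows. This is a minimal illustration with hypothetical three-dimensional embeddings; in practice each interpolated embedding would be passed to the model as its conditioning input. The names `lerp` and `interpolation_path` are not taken from the repository.

```python
def lerp(a, b, alpha):
    """Linear interpolation: returns a at alpha = 0 and b at alpha = 1."""
    return [(1.0 - alpha) * ai + alpha * bi for ai, bi in zip(a, b)]

def interpolation_path(a, b, num_steps):
    """Evenly spaced embeddings from a to b, endpoints included (num_steps >= 2)."""
    return [lerp(a, b, k / (num_steps - 1)) for k in range(num_steps)]

start = [0.0, 1.0, -2.0]  # first sampled context embedding (hypothetical)
end = [4.0, -1.0, 2.0]    # second sampled context embedding (hypothetical)

path = interpolation_path(start, end, 5)
```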

ModelNet10 Evaluation Metrics

Evaluation metrics on the test set. The model is competitive with other state-of-the-art approaches.

4. LIDAR-CS

This is not part of the thesis, but I may revisit this dataset in the future.

Dataset: LIDAR-CS

5. References

See my thesis: Advancements in Context-Aware Learning and Generative Modeling

6. Citation

If this repo is useful to you in your research, please cite my thesis and related work:

@mastersthesis{kuehmichel2024advancements,
    author={Lars Kühmichel},
    title={Advancements in Context-Aware Learning and Generative Modeling},
    school={Heidelberg University},
    year={2024},
    month={01},
    day={22},
}

@misc{müller2023contextaware,
    title={Towards Context-Aware Domain Generalization: Representing Environments with Permutation-Invariant Networks},
    author={Jens Müller and Lars Kühmichel and Martin Rohbeck and Stefan T. Radev and Ullrich Köthe},
    year={2023},
    eprint={2312.10107},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
}