KGPFN is a knowledge graph link prediction framework that combines:
- Structure encoder: relational message passing
- Feature transformer: LimiX or TabICL
conda create -n kgpfn python=3.12
conda activate kgpfn
pip install -r requirements.txt
Flash Attention (required for TabICL/LimiX):
Download the prebuilt wheel matching your CUDA, PyTorch, and Python versions from the flash-attention GitHub releases page, then install it with pip.
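Release wheel filenames encode the flash-attn version, the CUDA major version, the PyTorch major.minor version, the C++11 ABI flag, and the CPython tag. The POSIX shell sketch below rebuilds the filename and download URL from those components. It assumes every release follows the same naming pattern as the 2.8.0.post2 wheel in the example that follows; the function names are hypothetical helpers, not part of this repository. Not every combination is published, so check the releases page if a URL does not resolve.

```shell
# Hypothetical helpers (not part of this repository).
# Arguments: flash-attn version, CUDA major version, PyTorch major.minor,
#            C++11 ABI flag (TRUE/FALSE), CPython tag without the dot (e.g. 312).
flash_attn_wheel() {
  printf 'flash_attn-%s+cu%storch%scxx11abi%s-cp%s-cp%s-linux_x86_64.whl\n' \
    "$1" "$2" "$3" "$4" "$5" "$5"
}

# Same arguments; prints the release download URL for that wheel.
flash_attn_url() {
  printf 'https://github.com/Dao-AILab/flash-attention/releases/download/v%s/%s\n' \
    "$1" "$(flash_attn_wheel "$@")"
}

# Prints "CUDA_MAJOR TORCH_MAJOR.MINOR ABI PYTAG" for the active environment.
# Requires PyTorch to be importable; torch.version.cuda is None on CPU-only builds,
# and torch._C._GLIBCXX_USE_CXX11_ABI is a private attribute.
torch_wheel_components() {
  python -c 'import sys, torch
tv = ".".join(torch.__version__.split("+")[0].split(".")[:2])
cu = torch.version.cuda.split(".")[0]
abi = "TRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "FALSE"
print(cu, tv, abi, f"{sys.version_info.major}{sys.version_info.minor}")'
}
```

For example, `flash_attn_wheel 2.8.0.post2 12 2.7 TRUE 312` reproduces the filename used in the example below, and `flash_attn_url 2.8.0.post2 $(torch_wheel_components)` prints the URL for the current environment (the unquoted substitution is intentional so it splits into four arguments).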
# Example for CUDA 12.6 + PyTorch 2.7
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
pip install flash_attn-2.8.0.post2+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
Download structure encoder and feature transformer checkpoints if you want to retrain KGPFN:
# Download with TabICL as feature transformer (default)
python script/download.py --ft tabicl
# Download with LimiX instead
python script/download.py --ft limix
Or directly download a fully pretrained KGPFN model checkpoint:
# TabICL as the PFN architecture (default)
python script/download.py --kgpfn
# LimiX as the PFN architecture
python script/download.py --kgpfn limix
# TabICL as the PFN architecture with the all-MiniLM-L12-v2 semantic encoder
python script/download.py --kgpfn iclsemantic
All files are saved to the ./cache/ directory.
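Since every download lands in ./cache/, a quick pre-flight check can catch a missing download before a long training or evaluation run starts. This is a minimal sketch; `require_cache` is a hypothetical helper, not part of the repository, and it only checks that the directory exists and is non-empty, not that specific checkpoint files are present.

```shell
# Hypothetical pre-flight check: succeed only if the cache directory exists and
# contains at least one entry. Defaults to ./cache, where script/download.py saves files.
# The body runs in a subshell ( ... ), so `exit` leaves only the function, not your shell.
require_cache() (
  dir=${1:-./cache}
  if [ -d "$dir" ] && [ -n "$(ls -A "$dir")" ]; then
    exit 0
  fi
  echo "error: $dir is missing or empty; run script/download.py first" >&2
  exit 1
)
```

For example, `require_cache && CUDA_VISIBLE_DEVICES=0 python script/test_kgpfn.py -c config/script/test.yaml --gpus [0]` only starts evaluation when something has been downloaded.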
Edit config/script/train_all.yaml and set your dataset root:
dataset:
  root: /path/to/your/kg-datasets
Multi-GPU training:
accelerate launch --num_processes 8 script/pretrain_pfn.py -c config/script/train_all.yaml --gpus [0,1,2,3,4,5,6,7]
Evaluation on a single GPU:
CUDA_VISIBLE_DEVICES=0 python script/test_kgpfn.py \
-c config/script/test.yaml \
--gpus [0]
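The multi-GPU training command passes the same device count to `--num_processes` and `--gpus`. The sketch below builds the bracketed index list for any count, assuming you want one process per listed GPU and that `--gpus` accepts the same `[0,1,...]` form shown in the command; `gpu_list` is a hypothetical helper, not part of the repository.

```shell
# Hypothetical helper: print a bracketed list of GPU indices 0..N-1, e.g. [0,1,2,3].
# The body runs in a subshell ( ... ) so its variables do not leak into the caller.
gpu_list() (
  n=$1
  list=""
  i=0
  while [ "$i" -lt "$n" ]; do
    # Append the index, inserting a comma only after the first element.
    if [ -z "$list" ]; then list=$i; else list="$list,$i"; fi
    i=$((i + 1))
  done
  printf '[%s]\n' "$list"
)

# Usage (not executed here):
#   NGPU=4
#   accelerate launch --num_processes "$NGPU" script/pretrain_pfn.py \
#     -c config/script/train_all.yaml --gpus "$(gpu_list "$NGPU")"
```

Quoting the list also keeps the shell from expanding `[0,1,...]` as a glob pattern if a file with a matching single-character name happens to exist in the working directory. `nvidia-smi --list-gpus | wc -l` is one way to obtain the number of GPUs on the machine.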