# Replication of Text2Mol: Cross-Modal Molecular Retrieval with Natural Language Queries
Create a new conda environment for the project:

```bash
# Create conda environment
conda env create -f code/requirements.yaml

# Activate environment
conda activate text2mol

# Update environment (if needed)
conda env update -f code/requirements.yaml --prune
```

Train the Text2Mol model:
```bash
python code/main.py --data data --output_path test_output --model MLP --epochs 40 --batch_size 32
```

Rank embeddings and evaluate performance:
```bash
# Rank single model outputs
python code/ranker.py test_output/embeddings --train --val --test

# Rank ensemble of multiple models
python code/ensemble.py test_output/embeddings GCN_outputs/embeddings --train --val --test
```

Run example queries with a trained model:
```bash
python code/test_example.py test_output/embeddings/ data/ test_output/CHECKPOINT.pt
```

The project uses the ChEBI-20 dataset located in the `data/` directory. The dataset includes:
- Training/Validation/Test splits: `training.txt`, `val.txt`, `test.txt`
- Molecular graphs: `mol_graphs.zip` containing graph representations
- Token embeddings: `token_embedding_dict.npy` for molecular substructure tokens
- Corpus data: `ChEBI_defintions_substructure_corpus.cp` with tokenized descriptions
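The token embedding dictionary is a NumPy `.npy` file holding a Python dict, so it must be loaded with `allow_pickle=True`. The sketch below uses a small stand-in dict (the token names and embedding size are hypothetical; the real keys depend on the dataset):

```python
import numpy as np

# Hypothetical stand-in for token_embedding_dict.npy: a dict mapping
# substructure-token strings to embedding vectors. NumPy stores a dict
# as a 0-d object array.
demo = {
    "tok_a": np.zeros(300, dtype=np.float32),
    "tok_b": np.ones(300, dtype=np.float32),
}
np.save("token_embedding_dict_demo.npy", demo)

# allow_pickle=True is required for object arrays; .item() unwraps the
# 0-d object array back into the original dict.
token_embeddings = np.load("token_embedding_dict_demo.npy",
                           allow_pickle=True).item()

print(len(token_embeddings))            # number of tokens
print(token_embeddings["tok_a"].shape)  # embedding dimensionality
```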
Each data file contains:
- CID: PubChem Compound ID
- Mol2Vec embeddings: Pre-computed molecular embeddings
- ChEBI descriptions: Natural language descriptions of molecules
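A split file can be read with pandas; the sketch below uses an in-memory two-row stand-in and assumes tab-separated columns holding the CID, a space-separated Mol2Vec embedding, and a description (the real column layout may differ):

```python
import io

import numpy as np
import pandas as pd

# Hypothetical stand-in for training.txt: tab-separated CID,
# Mol2Vec embedding (space-separated floats), and ChEBI description.
demo = io.StringIO(
    "CID\tmol2vec\tdescription\n"
    "12345\t0.1 0.2 0.3\tA small demo molecule.\n"
    "67890\t0.4 0.5 0.6\tAnother demo molecule.\n"
)
df = pd.read_csv(demo, sep="\t")

# Parse the embedding strings into numeric vectors for downstream use.
df["mol2vec"] = df["mol2vec"].apply(
    lambda s: np.array(s.split(), dtype=np.float32)
)

print(df.shape)
print(df["mol2vec"].iloc[0])
```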
The implementation includes three model variants:
| Model | Description |
|---|---|
| MLP | Multi-layer perceptron for embedding projection |
| GCN | Graph Convolutional Network for molecular representation |
| Attention | Attention-based model for cross-modal learning |
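All three variants project molecules and descriptions into a shared embedding space, so retrieval reduces to ranking molecules by cosine similarity to a text query. The NumPy sketch below illustrates that ranking, plus retrieval metrics of the kind `ranker.py` reports; the embeddings are synthetic stand-ins, not model outputs:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shared-space embeddings: 5 molecules and their paired
# text descriptions (the text side is a noisy copy of the molecule side).
mol_emb = rng.normal(size=(5, 8))
text_emb = mol_emb + 0.1 * rng.normal(size=(5, 8))

def normalize(x):
    """L2-normalize each row so dot products become cosine similarities."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Cosine similarity between every text query and every molecule.
sim = normalize(text_emb) @ normalize(mol_emb).T  # shape (queries, molecules)

# Rank of the true (paired) molecule for each query; 1 = retrieved first.
order = np.argsort(-sim, axis=1)
ranks = np.array([np.where(order[i] == i)[0][0] + 1 for i in range(len(sim))])

print("mean rank:", ranks.mean())
print("hits@1:", (ranks == 1).mean())
print("MRR:", (1.0 / ranks).mean())
```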
| File | Purpose |
|---|---|
| main.py | Main training script |
| models.py | Model architecture definitions |
| dataloaders.py | Data loading and preprocessing |
| losses.py | Loss function implementations |
| ranker.py | Embedding ranking and evaluation |
| ensemble.py | Ensemble model evaluation |
| extract_embeddings.py | Embedding extraction utilities |
| test_example.py | Interactive testing interface |
| ranker_threshold.py | Threshold analysis and visualization |
Extract embeddings from a trained checkpoint:

```bash
python code/extract_embeddings.py \
    --data data \
    --output_path embedding_output_dir \
    --checkpoint test_output/CHECKPOINT.pt \
    --model MLP \
    --batch_size 32
```

Run threshold analysis and save the resulting plot:

```bash
python code/ranker_threshold.py test_output/embeddings \
    --train --val --test \
    --output_file threshold_analysis.png
```

Watch the project presentation: YouTube Video
Key dependencies include:
- PyTorch 1.11.0: Deep learning framework
- PyTorch Geometric: Graph neural networks
- Transformers 4.15.0: Pre-trained language models
- NumPy, Pandas: Data manipulation
- Matplotlib: Visualization
- Scikit-learn: Machine learning utilities
If you use this implementation in your research, please cite the original paper:
```bibtex
@inproceedings{edwards2021text2mol,
  title={Text2Mol: Cross-Modal Molecule Retrieval with Natural Language Queries},
  author={Edwards, Carl and Zhai, ChengXiang and Ji, Heng},
  booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing},
  pages={595--607},
  year={2021},
  url={https://aclanthology.org/2021.emnlp-main.47/}
}
```

This project is licensed under the MIT License - see the LICENSE file for details.