diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..7a60b85e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +*.pyc diff --git a/bindcraft.py b/bindcraft.py index 44eca693..cb4ddbef 100644 --- a/bindcraft.py +++ b/bindcraft.py @@ -67,6 +67,7 @@ script_start_time = time.time() trajectory_n = 1 accepted_designs = 0 +trajectory_runtime = TrajectoryDesignRuntime(advanced_settings) ### start design loop while True: @@ -108,7 +109,8 @@ ### Begin binder hallucination trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"], target_settings["target_hotspot_residues"], length, seed, helicity_value, - design_models, advanced_settings, design_paths, failure_csv) + design_models, advanced_settings, design_paths, failure_csv, + runtime=trajectory_runtime) trajectory_metrics = copy_dict(trajectory._tmp["best"]["aux"]["log"]) # contains plddt, ptm, i_ptm, pae, i_pae trajectory_pdb = os.path.join(design_paths["Trajectory"], design_name + ".pdb") @@ -174,6 +176,7 @@ ### MPNN redesign of starting binder mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings) + trajectory_runtime.invalidate() existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values) # create set of MPNN sequences with allowed amino acid composition diff --git a/functions/colabdesign_utils.py b/functions/colabdesign_utils.py index 4bb1db52..da4354a8 100644 --- a/functions/colabdesign_utils.py +++ b/functions/colabdesign_utils.py @@ -17,53 +17,182 @@ from .pyrosetta_utils import pr_relax, align_pdbs from .generic_utils import update_failures -# hallucinate a binder -def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv): - model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb") 
+_DEFAULT_TRAJECTORY_RUNTIME = None + + +class TrajectoryDesignRuntime: + """Reusable AF2 trajectory runtime to avoid per-trajectory model rebuilds.""" + + def __init__(self, advanced_settings): + self.advanced_settings = advanced_settings + self.af_model = None + self._rg_loss_added = False + self._iptm_loss_added = False + self._termini_loss_added = False + self._helix_loss_added = False + + def _ensure_model(self): + """ + (Re)initialize the AF2 design model if needed. + + The model is rebuilt when ``af_model`` is ``None`` – either on first + use or after an explicit ``invalidate()`` call. External calls to + ``clear_mem()`` (e.g. during MPNN validation) free the underlying + JAX device buffers, so callers **must** call ``invalidate()`` after + any such ``clear_mem()`` invocation to ensure a fresh model is + created on the next trajectory. + """ + if self.af_model is not None: + return + + clear_mem() + self.af_model = mk_afdesign_model( + protocol="binder", + debug=False, + data_dir=self.advanced_settings["af_params_dir"], + use_multimer=self.advanced_settings["use_multimer_design"], + num_recycles=self.advanced_settings["num_recycles_design"], + best_metric="loss", + ) + print("Initialised reusable AF2 trajectory model.") + + def invalidate(self): + """Mark the cached model as stale so ``_ensure_model`` rebuilds it. + + Call this after any external ``clear_mem()`` invocation that may + have freed the JAX device buffers backing ``self.af_model``. 
+ """ + self.af_model = None + self._rg_loss_added = False + self._iptm_loss_added = False + self._termini_loss_added = False + self._helix_loss_added = False + + def prepare_trajectory( + self, + starting_pdb, + chain, + target_hotspot_residues, + length, + seed, + helicity_value, + ): + self._ensure_model() + + hotspot = None if target_hotspot_residues == "" else target_hotspot_residues + + self.af_model.prep_inputs( + pdb_filename=starting_pdb, + chain=chain, + binder_len=length, + hotspot=hotspot, + seed=seed, + rm_aa=self.advanced_settings["omit_AAs"], + rm_target_seq=self.advanced_settings["rm_template_seq_design"], + rm_target_sc=self.advanced_settings["rm_template_sc_design"], + ) + # Reset recycles for each run in case prior runs changed it (e.g. 4stage beta optimisation). + self.af_model.set_opt(num_recycles=self.advanced_settings["num_recycles_design"]) + + self.af_model.opt["weights"].update( + { + "pae": self.advanced_settings["weights_pae_intra"], + "plddt": self.advanced_settings["weights_plddt"], + "i_pae": self.advanced_settings["weights_pae_inter"], + "con": self.advanced_settings["weights_con_intra"], + "i_con": self.advanced_settings["weights_con_inter"], + } + ) + + # Redefine intramolecular contacts (con) and intermolecular contacts (i_con) definitions. 
+ self.af_model.opt["con"].update( + { + "num": self.advanced_settings["intra_contact_number"], + "cutoff": self.advanced_settings["intra_contact_distance"], + "binary": False, + "seqsep": 9, + } + ) + self.af_model.opt["i_con"].update( + { + "num": self.advanced_settings["inter_contact_number"], + "cutoff": self.advanced_settings["inter_contact_distance"], + "binary": False, + } + ) - # clear GPU memory for new trajectory - clear_mem() + if self.advanced_settings["use_rg_loss"]: + if not self._rg_loss_added: + add_rg_loss(self.af_model, self.advanced_settings["weights_rg"]) + self._rg_loss_added = True + else: + self.af_model.opt["weights"]["rg"] = self.advanced_settings["weights_rg"] + elif self._rg_loss_added: + self.af_model.opt["weights"]["rg"] = 0.0 + + if self.advanced_settings["use_i_ptm_loss"]: + if not self._iptm_loss_added: + add_i_ptm_loss(self.af_model, self.advanced_settings["weights_iptm"]) + self._iptm_loss_added = True + else: + self.af_model.opt["weights"]["i_ptm"] = self.advanced_settings["weights_iptm"] + elif self._iptm_loss_added: + self.af_model.opt["weights"]["i_ptm"] = 0.0 + + if self.advanced_settings["use_termini_distance_loss"]: + if not self._termini_loss_added: + add_termini_distance_loss( + self.af_model, self.advanced_settings["weights_termini_loss"] + ) + self._termini_loss_added = True + else: + self.af_model.opt["weights"]["NC"] = self.advanced_settings[ + "weights_termini_loss" + ] + elif self._termini_loss_added: + self.af_model.opt["weights"]["NC"] = 0.0 + + if not self._helix_loss_added: + add_helix_loss(self.af_model, helicity_value) + self._helix_loss_added = True + else: + self.af_model.opt["weights"]["helix"] = helicity_value - # initialise binder hallucination model - af_model = mk_afdesign_model(protocol="binder", debug=False, data_dir=advanced_settings["af_params_dir"], - use_multimer=advanced_settings["use_multimer_design"], num_recycles=advanced_settings["num_recycles_design"], - best_metric='loss') - - # sanity check 
for hotspots - if target_hotspot_residues == "": - target_hotspot_residues = None - - af_model.prep_inputs(pdb_filename=starting_pdb, chain=chain, binder_len=length, hotspot=target_hotspot_residues, seed=seed, rm_aa=advanced_settings["omit_AAs"], - rm_target_seq=advanced_settings["rm_template_seq_design"], rm_target_sc=advanced_settings["rm_template_sc_design"]) - - ### Update weights based on specified settings - af_model.opt["weights"].update({"pae":advanced_settings["weights_pae_intra"], - "plddt":advanced_settings["weights_plddt"], - "i_pae":advanced_settings["weights_pae_inter"], - "con":advanced_settings["weights_con_intra"], - "i_con":advanced_settings["weights_con_inter"], - }) - - # redefine intramolecular contacts (con) and intermolecular contacts (i_con) definitions - af_model.opt["con"].update({"num":advanced_settings["intra_contact_number"],"cutoff":advanced_settings["intra_contact_distance"],"binary":False,"seqsep":9}) - af_model.opt["i_con"].update({"num":advanced_settings["inter_contact_number"],"cutoff":advanced_settings["inter_contact_distance"],"binary":False}) - + return self.af_model - ### additional loss functions - if advanced_settings["use_rg_loss"]: - # radius of gyration loss - add_rg_loss(af_model, advanced_settings["weights_rg"]) - if advanced_settings["use_i_ptm_loss"]: - # interface pTM loss - add_i_ptm_loss(af_model, advanced_settings["weights_iptm"]) +# hallucinate a binder +def binder_hallucination( + design_name, + starting_pdb, + chain, + target_hotspot_residues, + length, + seed, + helicity_value, + design_models, + advanced_settings, + design_paths, + failure_csv, + runtime=None, +): + global _DEFAULT_TRAJECTORY_RUNTIME + + model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb") - if advanced_settings["use_termini_distance_loss"]: - # termini distance loss - add_termini_distance_loss(af_model, advanced_settings["weights_termini_loss"]) + if runtime is None: + if _DEFAULT_TRAJECTORY_RUNTIME is None: + 
_DEFAULT_TRAJECTORY_RUNTIME = TrajectoryDesignRuntime(advanced_settings) + runtime = _DEFAULT_TRAJECTORY_RUNTIME - # add the helicity loss - add_helix_loss(af_model, helicity_value) + af_model = runtime.prepare_trajectory( + starting_pdb=starting_pdb, + chain=chain, + target_hotspot_residues=target_hotspot_residues, + length=length, + seed=seed, + helicity_value=helicity_value, + ) # calculate the number of mutations to do based on the length of the protein greedy_tries = math.ceil(length * (advanced_settings["greedy_percentage"] / 100)) @@ -102,6 +231,8 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu # if best iteration has high enough confidence then continue if initial_plddt > 0.65: print("Initial trajectory pLDDT good, continuing: "+str(initial_plddt)) + soft_iterations = advanced_settings["soft_iterations"] + temporary_iterations = advanced_settings["temporary_iterations"] if advanced_settings["optimise_beta"]: # temporarily dump model to assess secondary structure af_model.save_pdb(model_pdb_path) @@ -110,13 +241,13 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu # if beta sheeted trajectory is detected then choose to optimise if float(beta) > 15: - advanced_settings["soft_iterations"] = advanced_settings["soft_iterations"] + advanced_settings["optimise_beta_extra_soft"] - advanced_settings["temporary_iterations"] = advanced_settings["temporary_iterations"] + advanced_settings["optimise_beta_extra_temp"] + soft_iterations = soft_iterations + advanced_settings["optimise_beta_extra_soft"] + temporary_iterations = temporary_iterations + advanced_settings["optimise_beta_extra_temp"] af_model.set_opt(num_recycles=advanced_settings["optimise_beta_recycles_design"]) print("Beta sheeted trajectory detected, optimising settings") # how many logit iterations left - logits_iter = advanced_settings["soft_iterations"] - 50 + logits_iter = soft_iterations - 50 if logits_iter > 0: print("Stage 1: 
Additional Logits Optimisation") af_model.clear_best() @@ -129,10 +260,10 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu logit_plddt = initial_plddt # perform softmax trajectory design - if advanced_settings["temporary_iterations"] > 0: + if temporary_iterations > 0: print("Stage 2: Softmax Optimisation") af_model.clear_best() - af_model.design_soft(advanced_settings["temporary_iterations"], e_temp=1e-2, models=design_models, num_models=1, + af_model.design_soft(temporary_iterations, e_temp=1e-2, models=design_models, num_models=1, sample_models=advanced_settings["sample_models"], ramp_recycles=False, save_best=True) softmax_plddt = get_best_plddt(af_model, length) else: diff --git a/notebooks/BindCraft.ipynb b/notebooks/BindCraft.ipynb index a3586492..9233b179 100644 --- a/notebooks/BindCraft.ipynb +++ b/notebooks/BindCraft.ipynb @@ -1,893 +1,896 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "MTagjooX0XC1" - }, - "source": [ - "# BindCraft: Protein binder design\n", - "\n", - "\n", - "\n", - "Simple binder design pipeline using AlphaFold2 backpropagation, MPNN, and PyRosetta. Select your target and let the script do the rest of the work and finish once you have enough designs to order!\n", - "\n", - "The designs will be saved on your Google Drive under BindCraft/[design_name]/ and you can continue running the design pipeline if the session times out and it will continue adding new designs." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "7fMzl8JiyaXm" - }, - "outputs": [], - "source": [ - "#@title Installation\n", - "%%time\n", - "import os, time, gc, io\n", - "import contextlib\n", - "import json\n", - "from datetime import datetime\n", - "from ipywidgets import HTML, VBox\n", - "from IPython.display import display\n", - "\n", - "if not os.path.isfile(\"bindcraft/params/done.txt\"):\n", - " print(\"Installing required BindCraft components\")\n", - "\n", - " print(\"Pulling BindCraft code from Github\")\n", - " os.makedirs('/content/bindcraft/', exist_ok=True)\n", - " !git clone https://github.com/martinpacesa/BindCraft /content/bindcraft/\n", - " os.system(\"chmod +x /content/bindcraft/functions/dssp\")\n", - " os.system(\"chmod +x /content/bindcraft/functions/DAlphaBall.gcc\")\n", - "\n", - " print(\"Installing ColabDesign\")\n", - " os.system(\"(mkdir bindcraft/params; apt-get install aria2 -qq; \\\n", - " aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n", - " tar -xf alphafold_params_2022-12-06.tar -C bindcraft/params; touch bindcraft/params/done.txt )&\")\n", - " os.system(\"pip install git+https://github.com/sokrypton/ColabDesign.git\")\n", - " # for debugging purposes\n", - " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n", - "\n", - " print(\"Installing PyRosetta\")\n", - " os.system(\"pip install pyrosetta_installer\")\n", - " with contextlib.redirect_stdout(io.StringIO()):\n", - " import pyrosetta_installer\n", - " pyrosetta_installer.install_pyrosetta(serialization=True)\n", - "\n", - " # download params\n", - " if not os.path.isfile(\"bindcraft/params/done.txt\"):\n", - " print(\"downloading AlphaFold params\")\n", - " while not os.path.isfile(\"bindcraft/params/done.txt\"):\n", - " time.sleep(5)\n", - "\n", - " print(\"BindCraft installation is finished, ready to run!\")\n", - "else:\n", - " 
print(\"BindCraft components already installed, ready to run!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "01IH64-ycCQY" - }, - "outputs": [], - "source": [ - "#@title Mount your Google Drive to save design results\n", - "from google.colab import drive\n", - "drive.mount('/content/drive')\n", - "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", - "print(f\"Google drive mounted at: {currenttime}\")\n", - "\n", - "bindcraft_google_drive = '/content/drive/My Drive/BindCraft/'\n", - "os.makedirs(bindcraft_google_drive, exist_ok=True)\n", - "print(\"BindCraft folder successfully created in your drive!\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "RbL-S_t2hicj" - }, - "outputs": [], - "source": [ - "#@title Binder design settings\n", - "# @markdown ---\n", - "# @markdown Enter path where to save your designs. We recommend to save on Google drive so that you can continue generating at any time.\n", - "design_path = \"/content/drive/MyDrive/BindCraft/PDL1/\" # @param {\"type\":\"string\",\"placeholder\":\"/content/drive/MyDrive/BindCraft/PDL1/\"}\n", - "\n", - "# @markdown Enter the name that should be prefixed to your binders (generally target name).\n", - "binder_name = \"PDL1\" # @param {\"type\":\"string\",\"placeholder\":\"PDL1\"}\n", - "\n", - "# @markdown The path to the .pdb structure of your target. Can be an experimental or AlphaFold2 structure. We recommend trimming the structure to as small as needed, as the whole selected chains will be backpropagated through the network and can significantly increase running times.\n", - "starting_pdb = \"/content/bindcraft/example/PDL1.pdb\" # @param {\"type\":\"string\",\"placeholder\":\"/content/bindcraft/example/PDL1.pdb\"}\n", - "\n", - "# @markdown Which chains of your PDB to target? Can be one or multiple, in a comma-separated format. 
Other chains will be ignored during design.\n", - "chains = \"A\" # @param {\"type\":\"string\",\"placeholder\":\"A,C\"}\n", - "\n", - "# @markdown What positions to target in your protein of interest? For example `1,2-10` or chain specific `A1-10,B1-20` or entire chains `A`. If left blank, an appropriate site will be selected by the pipeline.\n", - "target_hotspot_residues = \"\" # @param {\"type\":\"string\",\"placeholder\":\"\"}\n", - "\n", - "# @markdown What is the minimum and maximum size of binders you want to design? Pipeline will randomly sample different sizes between these values.\n", - "lengths = \"60,140\" # @param {\"type\":\"string\",\"placeholder\":\"70,150\"}\n", - "\n", - "# @markdown How many binder designs passing filters do you require?\n", - "number_of_final_designs = 100 # @param {\"type\":\"integer\",\"placeholder\":\"100\"}\n", - "# @markdown ---\n", - "# @markdown Enter path on your Google drive (/content/drive/MyDrive/BindCraft/[binder_name].json) to previous target settings to continue design campaign. 
If left empty, it will use the settings above and generate a new settings json in your design output folder.\n", - "load_previous_target_settings = \"\" # @param {\"type\":\"string\",\"placeholder\":\"\"}\n", - "# @markdown ---\n", - "\n", - "if load_previous_target_settings:\n", - " target_settings_path = load_previous_target_settings\n", - "else:\n", - " lengths = [int(x.strip()) for x in lengths.split(',') if len(lengths.split(',')) == 2]\n", - "\n", - " if len(lengths) != 2:\n", - " raise ValueError(\"Incorrect specification of binder lengths.\")\n", - "\n", - " settings = {\n", - " \"design_path\": design_path,\n", - " \"binder_name\": binder_name,\n", - " \"starting_pdb\": starting_pdb,\n", - " \"chains\": chains,\n", - " \"target_hotspot_residues\": target_hotspot_residues,\n", - " \"lengths\": lengths,\n", - " \"number_of_final_designs\": number_of_final_designs\n", - " }\n", - "\n", - " target_settings_path = os.path.join(design_path, binder_name+\".json\")\n", - " os.makedirs(design_path, exist_ok=True)\n", - "\n", - " with open(target_settings_path, 'w') as f:\n", - " json.dump(settings, f, indent=4)\n", - "\n", - "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", - "print(f\"Binder design settings updated at: {currenttime}\")\n", - "print(f\"New .json file with target settings has been generated in: {target_settings_path}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "qcEjqCIlhire" - }, - "outputs": [], - "source": [ - "#@title Advanced settings\n", - "# @markdown ---\n", - "# @markdown Which binder design protocol to run? Default is recommended. \"Beta-sheet\" promotes the design of more beta sheeted proteins, but requires more sampling. \"Peptide\" is optimised for helical peptide binders.\n", - "design_protocol = \"Default\" # @param [\"Default\",\"Beta-sheet\",\"Peptide\"]\n", - "# @markdown What prediction protocol to use?. 
\"Default\" performs single sequence prediction of the binder. \"HardTarget\" uses initial guess to improve complex prediction for difficult targets, but might introduce some bias.\n", - "prediction_protocol = \"Default\" # @param [\"Default\",\"HardTarget\"]\n", - "# @markdown What interface design method to use?. \"AlphaFold2\" is the default, interface is generated by AlphaFold2. \"MPNN\" uses soluble MPNN to optimise the interface.\n", - "interface_protocol = \"AlphaFold2\" # @param [\"AlphaFold2\",\"MPNN\"]\n", - "# @markdown What target template protocol to use? \"Default\" allows for limited amount flexibility. \"Masked\" allows for greater target flexibility on both sidechain and backbone level.\n", - "template_protocol = \"Default\" # @param [\"Default\",\"Masked\"]\n", - "# @markdown ---\n", - "\n", - "if design_protocol == \"Default\":\n", - " design_protocol_tag = \"default_4stage_multimer\"\n", - "elif design_protocol == \"Beta-sheet\":\n", - " design_protocol_tag = \"betasheet_4stage_multimer\"\n", - "elif design_protocol == \"Peptide\":\n", - " design_protocol_tag = \"peptide_3stage_multimer\"\n", - "else:\n", - " raise ValueError(f\"Unsupported design protocol\")\n", - "\n", - "if interface_protocol == \"AlphaFold2\":\n", - " interface_protocol_tag = \"\"\n", - "elif interface_protocol == \"MPNN\":\n", - " interface_protocol_tag = \"_mpnn\"\n", - "else:\n", - " raise ValueError(f\"Unsupported interface protocol\")\n", - "\n", - "if template_protocol == \"Default\":\n", - " template_protocol_tag = \"\"\n", - "elif template_protocol == \"Masked\":\n", - " template_protocol_tag = \"_flexible\"\n", - "else:\n", - " raise ValueError(f\"Unsupported template protocol\")\n", - "\n", - "if design_protocol in [\"Peptide\"]:\n", - " prediction_protocol_tag = \"\"\n", - "else:\n", - " if prediction_protocol == \"Default\":\n", - " prediction_protocol_tag = \"\"\n", - " elif prediction_protocol == \"HardTarget\":\n", - " prediction_protocol_tag = 
\"_hardtarget\"\n", - " else:\n", - " raise ValueError(f\"Unsupported prediction protocol\")\n", - "\n", - "advanced_settings_path = \"/content/bindcraft/settings_advanced/\" + design_protocol_tag + interface_protocol_tag + template_protocol_tag + prediction_protocol_tag + \".json\"\n", - "\n", - "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", - "print(f\"Advanced design settings updated at: {currenttime}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "r-OpCVe4hi5Q" - }, - "outputs": [], - "source": [ - "#@title Filters\n", - "# @markdown ---\n", - "# @markdown Which filters for designs to use? \"Default\" are recommended, \"Peptide\" are for the design of peptide binders, \"Relaxed\" are more permissive but may result in fewer experimental successes, \"Peptide_Relaxed\" are more permissive filters for non-helical peptides, \"None\" is for benchmarking.\n", - "filter_option = \"Default\" # @param [\"Default\", \"Peptide\", \"Relaxed\", \"Peptide_Relaxed\", \"None\"]\n", - "# @markdown ---\n", - "\n", - "if filter_option == \"Default\":\n", - " filter_settings_path = \"/content/bindcraft/settings_filters/default_filters.json\"\n", - "elif filter_option == \"Peptide\":\n", - " filter_settings_path = \"/content/bindcraft/settings_filters/peptide_filters.json\"\n", - "elif filter_option == \"Relaxed\":\n", - " filter_settings_path = \"/content/bindcraft/settings_filters/relaxed_filters.json\"\n", - "elif filter_option == \"Peptide_Relaxed\":\n", - " filter_settings_path = \"/content/bindcraft/settings_filters/peptide_relaxed_filters.json\"\n", - "elif filter_option == \"None\":\n", - " filter_settings_path = \"/content/bindcraft/settings_filters/no_filters.json\"\n", - "else:\n", - " raise ValueError(f\"Unsupported filter type\")\n", - "\n", - "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", - "print(f\"Filter settings updated at: {currenttime}\")" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "id": "BR3gtmcChtvX" - }, - "source": [ - "# Everything is set, BindCraft is ready to run!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "LgRFO3EKAnM5" - }, - "outputs": [], - "source": [ - "# @title Import functions and settings\n", - "from bindcraft.functions import *\n", - "\n", - "args = {\"settings\":target_settings_path,\n", - " \"filters\":filter_settings_path,\n", - " \"advanced\":advanced_settings_path}\n", - "\n", - "# Check if JAX-capable GPU is available, otherwise exit\n", - "check_jax_gpu()\n", - "\n", - "# perform checks of input setting files\n", - "settings_path, filters_path, advanced_path = (args[\"settings\"], args[\"filters\"], args[\"advanced\"])\n", - "\n", - "### load settings from JSON\n", - "target_settings, advanced_settings, filters = load_json_settings(settings_path, filters_path, advanced_path)\n", - "\n", - "settings_file = os.path.basename(settings_path).split('.')[0]\n", - "filters_file = os.path.basename(filters_path).split('.')[0]\n", - "advanced_file = os.path.basename(advanced_path).split('.')[0]\n", - "\n", - "### load AF2 model settings\n", - "design_models, prediction_models, multimer_validation = load_af2_models(advanced_settings[\"use_multimer_design\"])\n", - "\n", - "### perform checks on advanced_settings\n", - "bindcraft_folder = \"colab\"\n", - "advanced_settings = perform_advanced_settings_check(advanced_settings, bindcraft_folder)\n", - "\n", - "### generate directories, design path names can be found within the function\n", - "design_paths = generate_directories(target_settings[\"design_path\"])\n", - "\n", - "### generate dataframes\n", - "trajectory_labels, design_labels, final_labels = generate_dataframe_labels()\n", - "\n", - "trajectory_csv = os.path.join(target_settings[\"design_path\"], 'trajectory_stats.csv')\n", - "mpnn_csv = os.path.join(target_settings[\"design_path\"], 'mpnn_design_stats.csv')\n", - 
"final_csv = os.path.join(target_settings[\"design_path\"], 'final_design_stats.csv')\n", - "failure_csv = os.path.join(target_settings[\"design_path\"], 'failure_csv.csv')\n", - "\n", - "create_dataframe(trajectory_csv, trajectory_labels)\n", - "create_dataframe(mpnn_csv, design_labels)\n", - "create_dataframe(final_csv, final_labels)\n", - "generate_filter_pass_csv(failure_csv, args[\"filters\"])\n", - "\n", - "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", - "print(f\"Loaded design functions and settings at: {currenttime}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "1sOAn_xyEZKo" - }, - "outputs": [], - "source": [ - "#@title Initialise PyRosetta\n", - "\n", - "####################################\n", - "####################################\n", - "####################################\n", - "### initialise PyRosetta\n", - "pr.init(f'-ignore_unrecognized_res -ignore_zero_occupancy -mute all -holes:dalphaball {advanced_settings[\"dalphaball_path\"]} -corrections::beta_nov16 true -relax:default_repeats 1')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ZH2hVVrpzn-o" - }, - "outputs": [], - "source": [ - "#@title Run BindCraft!\n", - "####################################\n", - "###################### BindCraft Run\n", - "####################################\n", - "# Colab-specific: live displays\n", - "num_sampled_trajectories = len(pd.read_csv(trajectory_csv))\n", - "num_accepted_designs = len(pd.read_csv(final_csv))\n", - "sampled_trajectories_label = HTML(value=f\"

Sampled Trajectories: {num_sampled_trajectories}

\")\n", - "accepted_designs_label = HTML(value=f\"

Accepted Designs: {num_accepted_designs}

\")\n", - "display(VBox([sampled_trajectories_label, accepted_designs_label]))\n", - "\n", - "# initialise counters\n", - "script_start_time = time.time()\n", - "trajectory_n = 1\n", - "accepted_designs = 0\n", - "\n", - "### start design loop\n", - "while True:\n", - " ### check if we have the target number of binders\n", - " final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels)\n", - "\n", - " if final_designs_reached:\n", - " # stop design loop execution\n", - " break\n", - "\n", - " ### check if we reached maximum allowed trajectories\n", - " max_trajectories_reached = check_n_trajectories(design_paths, advanced_settings)\n", - "\n", - " if max_trajectories_reached:\n", - " break\n", - "\n", - " ### Initialise design\n", - " # measure time to generate design\n", - " trajectory_start_time = time.time()\n", - "\n", - " # generate random seed to vary designs\n", - " seed = int(np.random.randint(0, high=999999, size=1, dtype=int)[0])\n", - "\n", - " # sample binder design length randomly from defined distribution\n", - " samples = np.arange(min(target_settings[\"lengths\"]), max(target_settings[\"lengths\"]) + 1)\n", - " length = np.random.choice(samples)\n", - "\n", - " # load desired helicity value to sample different secondary structure contents\n", - " helicity_value = load_helicity(advanced_settings)\n", - "\n", - " # generate design name and check if same trajectory was already run\n", - " design_name = target_settings[\"binder_name\"] + \"_l\" + str(length) + \"_s\"+ str(seed)\n", - " trajectory_dirs = [\"Trajectory\", \"Trajectory/Relaxed\", \"Trajectory/LowConfidence\", \"Trajectory/Clashing\"]\n", - " trajectory_exists = any(os.path.exists(os.path.join(design_paths[trajectory_dir], design_name + \".pdb\")) for trajectory_dir in trajectory_dirs)\n", - "\n", - " if not trajectory_exists:\n", - " print(\"Starting trajectory: \"+design_name)\n", - "\n", - " ### Begin 
binder hallucination\n", - " trajectory = binder_hallucination(design_name, target_settings[\"starting_pdb\"], target_settings[\"chains\"],\n", - " target_settings[\"target_hotspot_residues\"], length, seed, helicity_value,\n", - " design_models, advanced_settings, design_paths, failure_csv)\n", - " trajectory_metrics = copy_dict(trajectory._tmp[\"best\"][\"aux\"][\"log\"]) # contains plddt, ptm, i_ptm, pae, i_pae\n", - " trajectory_pdb = os.path.join(design_paths[\"Trajectory\"], design_name + \".pdb\")\n", - "\n", - " # round the metrics to two decimal places\n", - " trajectory_metrics = {k: round(v, 2) if isinstance(v, float) else v for k, v in trajectory_metrics.items()}\n", - "\n", - " # time trajectory\n", - " trajectory_time = time.time() - trajectory_start_time\n", - " trajectory_time_text = f\"{'%d hours, %d minutes, %d seconds' % (int(trajectory_time // 3600), int((trajectory_time % 3600) // 60), int(trajectory_time % 60))}\"\n", - " print(\"Starting trajectory took: \"+trajectory_time_text)\n", - " print(\"\")\n", - "\n", - " # Proceed if there is no trajectory termination signal\n", - " if trajectory.aux[\"log\"][\"terminate\"] == \"\":\n", - " # Relax binder to calculate statistics\n", - " trajectory_relaxed = os.path.join(design_paths[\"Trajectory/Relaxed\"], design_name + \".pdb\")\n", - " pr_relax(trajectory_pdb, trajectory_relaxed)\n", - "\n", - " # define binder chain, placeholder in case multi-chain parsing in ColabDesign gets changed\n", - " binder_chain = \"B\"\n", - "\n", - " # Calculate clashes before and after relaxation\n", - " num_clashes_trajectory = calculate_clash_score(trajectory_pdb)\n", - " num_clashes_relaxed = calculate_clash_score(trajectory_relaxed)\n", - "\n", - " # secondary structure content of starting trajectory binder and interface\n", - " trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_alpha_interface, trajectory_beta_interface, trajectory_loops_interface, trajectory_i_plddt, trajectory_ss_plddt = 
calc_ss_percentage(trajectory_pdb, advanced_settings, binder_chain)\n", - "\n", - " # analyze interface scores for relaxed af2 trajectory\n", - " trajectory_interface_scores, trajectory_interface_AA, trajectory_interface_residues = score_interface(trajectory_relaxed, binder_chain)\n", - "\n", - " # starting binder sequence\n", - " trajectory_sequence = trajectory.get_seq(get_best=True)[0]\n", - "\n", - " # analyze sequence\n", - " traj_seq_notes = validate_design_sequence(trajectory_sequence, num_clashes_relaxed, advanced_settings)\n", - "\n", - " # target structure RMSD compared to input PDB\n", - " trajectory_target_rmsd = unaligned_rmsd(target_settings[\"starting_pdb\"], trajectory_pdb, target_settings[\"chains\"], 'A')\n", - "\n", - " # save trajectory statistics into CSV\n", - " trajectory_data = [design_name, advanced_settings[\"design_algorithm\"], length, seed, helicity_value, target_settings[\"target_hotspot_residues\"], trajectory_sequence, trajectory_interface_residues,\n", - " trajectory_metrics['plddt'], trajectory_metrics['ptm'], trajectory_metrics['i_ptm'], trajectory_metrics['pae'], trajectory_metrics['i_pae'],\n", - " trajectory_i_plddt, trajectory_ss_plddt, num_clashes_trajectory, num_clashes_relaxed, trajectory_interface_scores['binder_score'],\n", - " trajectory_interface_scores['surface_hydrophobicity'], trajectory_interface_scores['interface_sc'], trajectory_interface_scores['interface_packstat'],\n", - " trajectory_interface_scores['interface_dG'], trajectory_interface_scores['interface_dSASA'], trajectory_interface_scores['interface_dG_SASA_ratio'],\n", - " trajectory_interface_scores['interface_fraction'], trajectory_interface_scores['interface_hydrophobicity'], trajectory_interface_scores['interface_nres'], trajectory_interface_scores['interface_interface_hbonds'],\n", - " trajectory_interface_scores['interface_hbond_percentage'], trajectory_interface_scores['interface_delta_unsat_hbonds'], 
trajectory_interface_scores['interface_delta_unsat_hbonds_percentage'],\n", - " trajectory_alpha_interface, trajectory_beta_interface, trajectory_loops_interface, trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_interface_AA, trajectory_target_rmsd,\n", - " trajectory_time_text, traj_seq_notes, settings_file, filters_file, advanced_file]\n", - " insert_data(trajectory_csv, trajectory_data)\n", - "\n", - " if advanced_settings[\"enable_mpnn\"]:\n", - " # initialise MPNN counters\n", - " mpnn_n = 1\n", - " accepted_mpnn = 0\n", - " mpnn_dict = {}\n", - " design_start_time = time.time()\n", - "\n", - " ### MPNN redesign of starting binder\n", - " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n", - " existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)\n", - "\n", - " # create set of MPNN sequences with allowed amino acid composition\n", - " restricted_AAs = set(aa.strip().upper() for aa in advanced_settings[\"omit_AAs\"].split(',')) if advanced_settings[\"force_reject_AA\"] else set()\n", - "\n", - " mpnn_sequences = sorted({\n", - " mpnn_trajectories['seq'][n][-length:]: {\n", - " 'seq': mpnn_trajectories['seq'][n][-length:],\n", - " 'score': mpnn_trajectories['score'][n],\n", - " 'seqid': mpnn_trajectories['seqid'][n]\n", - " } for n in range(advanced_settings[\"num_seqs\"])\n", - " if (not restricted_AAs or not any(aa in mpnn_trajectories['seq'][n][-length:].upper() for aa in restricted_AAs))\n", - " and mpnn_trajectories['seq'][n][-length:] not in existing_mpnn_sequences\n", - " }.values(), key=lambda x: x['score'])\n", - "\n", - " del existing_mpnn_sequences\n", - "\n", - " # check whether any sequences are left after amino acid rejection and duplication check, and if yes proceed with prediction\n", - " if mpnn_sequences:\n", - " # add optimisation for increasing recycles if trajectory is beta sheeted\n", - " if 
advanced_settings[\"optimise_beta\"] and float(trajectory_beta) > 15:\n", - " advanced_settings[\"num_recycles_validation\"] = advanced_settings[\"optimise_beta_recycles_valid\"]\n", - "\n", - " ### Compile prediction models once for faster prediction of MPNN sequences\n", - " clear_mem()\n", - " # compile complex prediction model\n", - " complex_prediction_model = mk_afdesign_model(protocol=\"binder\", num_recycles=advanced_settings[\"num_recycles_validation\"], data_dir=advanced_settings[\"af_params_dir\"],\n", - " use_multimer=multimer_validation)\n", - " complex_prediction_model.prep_inputs(pdb_filename=target_settings[\"starting_pdb\"], chain=target_settings[\"chains\"], binder_len=length, rm_target_seq=advanced_settings[\"rm_template_seq_predict\"],\n", - " rm_target_sc=advanced_settings[\"rm_template_sc_predict\"])\n", - "\n", - " # compile binder monomer prediction model\n", - " binder_prediction_model = mk_afdesign_model(protocol=\"hallucination\", use_templates=False, initial_guess=False,\n", - " use_initial_atom_pos=False, num_recycles=advanced_settings[\"num_recycles_validation\"],\n", - " data_dir=advanced_settings[\"af_params_dir\"], use_multimer=multimer_validation)\n", - " binder_prediction_model.prep_inputs(length=length)\n", - "\n", - " # iterate over designed sequences\n", - " for mpnn_sequence in mpnn_sequences:\n", - " mpnn_time = time.time()\n", - "\n", - " # generate mpnn design name numbering\n", - " mpnn_design_name = design_name + \"_mpnn\" + str(mpnn_n)\n", - " mpnn_score = round(mpnn_sequence['score'],2)\n", - " mpnn_seqid = round(mpnn_sequence['seqid'],2)\n", - "\n", - " # add design to dictionary\n", - " mpnn_dict[mpnn_design_name] = {'seq': mpnn_sequence['seq'], 'score': mpnn_score, 'seqid': mpnn_seqid}\n", - "\n", - " # save fasta sequence\n", - " if advanced_settings[\"save_mpnn_fasta\"] is True:\n", - " save_fasta(mpnn_design_name, mpnn_sequence['seq'], design_paths)\n", - "\n", - " ### Predict mpnn redesigned binder complex using 
masked templates\n", - " mpnn_complex_statistics, pass_af2_filters = predict_binder_complex(complex_prediction_model,\n", - " mpnn_sequence['seq'], mpnn_design_name,\n", - " target_settings[\"starting_pdb\"], target_settings[\"chains\"],\n", - " length, trajectory_pdb, prediction_models, advanced_settings,\n", - " filters, design_paths, failure_csv)\n", - "\n", - " # if AF2 filters are not passed then skip the scoring\n", - " if not pass_af2_filters:\n", - " print(f\"Base AF2 filters not passed for {mpnn_design_name}, skipping interface scoring\")\n", - " mpnn_n += 1\n", - " continue\n", - "\n", - " # calculate statistics for each model individually\n", - " for model_num in prediction_models:\n", - " mpnn_design_pdb = os.path.join(design_paths[\"MPNN\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", - " mpnn_design_relaxed = os.path.join(design_paths[\"MPNN/Relaxed\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", - "\n", - " if os.path.exists(mpnn_design_pdb):\n", - " # Calculate clashes before and after relaxation\n", - " num_clashes_mpnn = calculate_clash_score(mpnn_design_pdb)\n", - " num_clashes_mpnn_relaxed = calculate_clash_score(mpnn_design_relaxed)\n", - "\n", - " # analyze interface scores for relaxed af2 trajectory\n", - " mpnn_interface_scores, mpnn_interface_AA, mpnn_interface_residues = score_interface(mpnn_design_relaxed, binder_chain)\n", - "\n", - " # secondary structure content of starting trajectory binder\n", - " mpnn_alpha, mpnn_beta, mpnn_loops, mpnn_alpha_interface, mpnn_beta_interface, mpnn_loops_interface, mpnn_i_plddt, mpnn_ss_plddt = calc_ss_percentage(mpnn_design_pdb, advanced_settings, binder_chain)\n", - "\n", - " # unaligned RMSD calculate to determine if binder is in the designed binding site\n", - " rmsd_site = unaligned_rmsd(trajectory_pdb, mpnn_design_pdb, binder_chain, binder_chain)\n", - "\n", - " # calculate RMSD of target compared to input PDB\n", - " target_rmsd = target_pdb_rmsd(mpnn_design_pdb, 
target_settings[\"starting_pdb\"], target_settings[\"chains\"])\n", - "\n", - " # add the additional statistics to the mpnn_complex_statistics dictionary\n", - " mpnn_complex_statistics[model_num+1].update({\n", - " 'i_pLDDT': mpnn_i_plddt,\n", - " 'ss_pLDDT': mpnn_ss_plddt,\n", - " 'Unrelaxed_Clashes': num_clashes_mpnn,\n", - " 'Relaxed_Clashes': num_clashes_mpnn_relaxed,\n", - " 'Binder_Energy_Score': mpnn_interface_scores['binder_score'],\n", - " 'Surface_Hydrophobicity': mpnn_interface_scores['surface_hydrophobicity'],\n", - " 'ShapeComplementarity': mpnn_interface_scores['interface_sc'],\n", - " 'PackStat': mpnn_interface_scores['interface_packstat'],\n", - " 'dG': mpnn_interface_scores['interface_dG'],\n", - " 'dSASA': mpnn_interface_scores['interface_dSASA'],\n", - " 'dG/dSASA': mpnn_interface_scores['interface_dG_SASA_ratio'],\n", - " 'Interface_SASA_%': mpnn_interface_scores['interface_fraction'],\n", - " 'Interface_Hydrophobicity': mpnn_interface_scores['interface_hydrophobicity'],\n", - " 'n_InterfaceResidues': mpnn_interface_scores['interface_nres'],\n", - " 'n_InterfaceHbonds': mpnn_interface_scores['interface_interface_hbonds'],\n", - " 'InterfaceHbondsPercentage': mpnn_interface_scores['interface_hbond_percentage'],\n", - " 'n_InterfaceUnsatHbonds': mpnn_interface_scores['interface_delta_unsat_hbonds'],\n", - " 'InterfaceUnsatHbondsPercentage': mpnn_interface_scores['interface_delta_unsat_hbonds_percentage'],\n", - " 'InterfaceAAs': mpnn_interface_AA,\n", - " 'Interface_Helix%': mpnn_alpha_interface,\n", - " 'Interface_BetaSheet%': mpnn_beta_interface,\n", - " 'Interface_Loop%': mpnn_loops_interface,\n", - " 'Binder_Helix%': mpnn_alpha,\n", - " 'Binder_BetaSheet%': mpnn_beta,\n", - " 'Binder_Loop%': mpnn_loops,\n", - " 'Hotspot_RMSD': rmsd_site,\n", - " 'Target_RMSD': target_rmsd\n", - " })\n", - "\n", - " # save space by removing unrelaxed predicted mpnn complex pdb?\n", - " if advanced_settings[\"remove_unrelaxed_complex\"]:\n", - " 
os.remove(mpnn_design_pdb)\n", - "\n", - " # calculate complex averages\n", - " mpnn_complex_averages = calculate_averages(mpnn_complex_statistics, handle_aa=True)\n", - "\n", - " ### Predict binder alone in single sequence mode\n", - " binder_statistics = predict_binder_alone(binder_prediction_model, mpnn_sequence['seq'], mpnn_design_name, length,\n", - " trajectory_pdb, binder_chain, prediction_models, advanced_settings, design_paths)\n", - "\n", - " # extract RMSDs of binder to the original trajectory\n", - " for model_num in prediction_models:\n", - " mpnn_binder_pdb = os.path.join(design_paths[\"MPNN/Binder\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", - "\n", - " if os.path.exists(mpnn_binder_pdb):\n", - " rmsd_binder = unaligned_rmsd(trajectory_pdb, mpnn_binder_pdb, binder_chain, \"A\")\n", - "\n", - " # append to statistics\n", - " binder_statistics[model_num+1].update({\n", - " 'Binder_RMSD': rmsd_binder\n", - " })\n", - "\n", - " # save space by removing binder monomer models?\n", - " if advanced_settings[\"remove_binder_monomer\"]:\n", - " os.remove(mpnn_binder_pdb)\n", - "\n", - " # calculate binder averages\n", - " binder_averages = calculate_averages(binder_statistics)\n", - "\n", - " # analyze sequence to make sure there are no cysteins and it contains residues that absorb UV for detection\n", - " seq_notes = validate_design_sequence(mpnn_sequence['seq'], mpnn_complex_averages.get('Relaxed_Clashes', None), advanced_settings)\n", - "\n", - " # measure time to generate design\n", - " mpnn_end_time = time.time() - mpnn_time\n", - " elapsed_mpnn_text = f\"{'%d hours, %d minutes, %d seconds' % (int(mpnn_end_time // 3600), int((mpnn_end_time % 3600) // 60), int(mpnn_end_time % 60))}\"\n", - "\n", - "\n", - " # Insert statistics about MPNN design into CSV, will return None if corresponding model does note exist\n", - " model_numbers = range(1, 6)\n", - " statistics_labels = ['pLDDT', 'pTM', 'i_pTM', 'pAE', 'i_pAE', 'i_pLDDT', 'ss_pLDDT', 
'Unrelaxed_Clashes', 'Relaxed_Clashes', 'Binder_Energy_Score', 'Surface_Hydrophobicity',\n", - " 'ShapeComplementarity', 'PackStat', 'dG', 'dSASA', 'dG/dSASA', 'Interface_SASA_%', 'Interface_Hydrophobicity', 'n_InterfaceResidues', 'n_InterfaceHbonds', 'InterfaceHbondsPercentage',\n", - " 'n_InterfaceUnsatHbonds', 'InterfaceUnsatHbondsPercentage', 'Interface_Helix%', 'Interface_BetaSheet%', 'Interface_Loop%', 'Binder_Helix%',\n", - " 'Binder_BetaSheet%', 'Binder_Loop%', 'InterfaceAAs', 'Hotspot_RMSD', 'Target_RMSD']\n", - "\n", - " # Initialize mpnn_data with the non-statistical data\n", - " mpnn_data = [mpnn_design_name, advanced_settings[\"design_algorithm\"], length, seed, helicity_value, target_settings[\"target_hotspot_residues\"], mpnn_sequence['seq'], mpnn_interface_residues, mpnn_score, mpnn_seqid]\n", - "\n", - " # Add the statistical data for mpnn_complex\n", - " for label in statistics_labels:\n", - " mpnn_data.append(mpnn_complex_averages.get(label, None))\n", - " for model in model_numbers:\n", - " mpnn_data.append(mpnn_complex_statistics.get(model, {}).get(label, None))\n", - "\n", - " # Add the statistical data for binder\n", - " for label in ['pLDDT', 'pTM', 'pAE', 'Binder_RMSD']: # These are the labels for binder alone\n", - " mpnn_data.append(binder_averages.get(label, None))\n", - " for model in model_numbers:\n", - " mpnn_data.append(binder_statistics.get(model, {}).get(label, None))\n", - "\n", - " # Add the remaining non-statistical data\n", - " mpnn_data.extend([elapsed_mpnn_text, seq_notes, settings_file, filters_file, advanced_file])\n", - "\n", - " # insert data into csv\n", - " insert_data(mpnn_csv, mpnn_data)\n", - "\n", - " # find best model number by pLDDT\n", - " plddt_values = {i: mpnn_data[i] for i in range(11, 15) if mpnn_data[i] is not None}\n", - "\n", - " # Find the key with the highest value\n", - " highest_plddt_key = int(max(plddt_values, key=plddt_values.get))\n", - "\n", - " # Output the number part of the key\n", - " 
best_model_number = highest_plddt_key - 10\n", - " best_model_pdb = os.path.join(design_paths[\"MPNN/Relaxed\"], f\"{mpnn_design_name}_model{best_model_number}.pdb\")\n", - "\n", - " # run design data against filter thresholds\n", - " filter_conditions = check_filters(mpnn_data, design_labels, filters)\n", - " if filter_conditions == True:\n", - " print(mpnn_design_name+\" passed all filters\")\n", - " accepted_mpnn += 1\n", - " accepted_designs += 1\n", - "\n", - " # copy designs to accepted folder\n", - " shutil.copy(best_model_pdb, design_paths[\"Accepted\"])\n", - "\n", - " # insert data into final csv\n", - " final_data = [''] + mpnn_data\n", - " insert_data(final_csv, final_data)\n", - "\n", - " # copy animation from accepted trajectory\n", - " if advanced_settings[\"save_design_animations\"]:\n", - " accepted_animation = os.path.join(design_paths[\"Accepted/Animation\"], f\"{design_name}.html\")\n", - " if not os.path.exists(accepted_animation):\n", - " shutil.copy(os.path.join(design_paths[\"Trajectory/Animation\"], f\"{design_name}.html\"), accepted_animation)\n", - "\n", - " # copy plots of accepted trajectory\n", - " plot_files = os.listdir(design_paths[\"Trajectory/Plots\"])\n", - " plots_to_copy = [f for f in plot_files if f.startswith(design_name) and f.endswith('.png')]\n", - " for accepted_plot in plots_to_copy:\n", - " source_plot = os.path.join(design_paths[\"Trajectory/Plots\"], accepted_plot)\n", - " target_plot = os.path.join(design_paths[\"Accepted/Plots\"], accepted_plot)\n", - " if not os.path.exists(target_plot):\n", - " shutil.copy(source_plot, target_plot)\n", - "\n", - " else:\n", - " print(f\"Unmet filter conditions for {mpnn_design_name}\")\n", - " failure_df = pd.read_csv(failure_csv)\n", - " special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')\n", - " incremented_columns = set()\n", - "\n", - " for column in filter_conditions:\n", - " base_column = column\n", - " for prefix in special_prefixes:\n", - " if 
column.startswith(prefix):\n", - " base_column = column.split('_', 1)[1]\n", - "\n", - " if base_column not in incremented_columns:\n", - " failure_df[base_column] = failure_df[base_column] + 1\n", - " incremented_columns.add(base_column)\n", - "\n", - " failure_df.to_csv(failure_csv, index=False)\n", - " shutil.copy(best_model_pdb, design_paths[\"Rejected\"])\n", - "\n", - " # increase MPNN design number\n", - " mpnn_n += 1\n", - "\n", - " # if enough mpnn sequences of the same trajectory pass filters then stop\n", - " if accepted_mpnn >= advanced_settings[\"max_mpnn_sequences\"]:\n", - " break\n", - "\n", - " if accepted_mpnn >= 1:\n", - " print(\"Found \"+str(accepted_mpnn)+\" MPNN designs passing filters\")\n", - " else:\n", - " print(\"No accepted MPNN designs found for this trajectory.\")\n", - "\n", - " else:\n", - " print('Duplicate MPNN designs sampled with different trajectory, skipping current trajectory optimisation')\n", - "\n", - " # save space by removing unrelaxed design trajectory PDB\n", - " if advanced_settings[\"remove_unrelaxed_trajectory\"]:\n", - " os.remove(trajectory_pdb)\n", - "\n", - " # measure time it took to generate designs for one trajectory\n", - " design_time = time.time() - design_start_time\n", - " design_time_text = f\"{'%d hours, %d minutes, %d seconds' % (int(design_time // 3600), int((design_time % 3600) // 60), int(design_time % 60))}\"\n", - " print(\"Design and validation of trajectory \"+design_name+\" took: \"+design_time_text)\n", - "\n", - " # analyse the rejection rate of trajectories to see if we need to readjust the design weights\n", - " if trajectory_n >= advanced_settings[\"start_monitoring\"] and advanced_settings[\"enable_rejection_check\"]:\n", - " acceptance = accepted_designs / trajectory_n\n", - " if not acceptance >= advanced_settings[\"acceptance_rate\"]:\n", - " print(\"The ratio of successful designs is lower than defined acceptance rate! 
Consider changing your design settings!\")\n", - " print(\"Script execution stopping...\")\n", - " break\n", - "\n", - " # increase trajectory number\n", - " trajectory_n += 1\n", - "\n", - " # Colab-specific: update counters\n", - " num_sampled_trajectories = len(pd.read_csv(trajectory_csv))\n", - " num_accepted_designs = len(pd.read_csv(final_csv))\n", - " sampled_trajectories_label.value = f\"Sampled trajectories: {num_sampled_trajectories}\"\n", - " accepted_designs_label.value = f\"Accepted designs: {num_accepted_designs}\"\n", - "\n", - "### Script finished\n", - "elapsed_time = time.time() - script_start_time\n", - "elapsed_text = f\"{'%d hours, %d minutes, %d seconds' % (int(elapsed_time // 3600), int((elapsed_time % 3600) // 60), int(elapsed_time % 60))}\"\n", - "print(\"Finished all designs. Script execution for \"+str(trajectory_n)+\" trajectories took: \"+elapsed_text)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jdmYnBypaUHR" - }, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "U5qKDGGabhZo" - }, - "outputs": [], - "source": [ - "#@title Consolidate & Rank Designs\n", - "#@markdown ---\n", - "accepted_binders = [f for f in os.listdir(design_paths[\"Accepted\"]) if f.endswith('.pdb')]\n", - "\n", - "for f in os.listdir(design_paths[\"Accepted/Ranked\"]):\n", - " os.remove(os.path.join(design_paths[\"Accepted/Ranked\"], f))\n", - "\n", - "# load dataframe of designed binders\n", - "design_df = pd.read_csv(mpnn_csv)\n", - "design_df = design_df.sort_values('Average_i_pTM', ascending=False)\n", - "\n", - "# create final csv dataframe to copy matched rows, initialize with the column labels\n", - "final_df = pd.DataFrame(columns=final_labels)\n", - "\n", - "# check the ranking of the designs and copy them with new ranked IDs to the folder\n", - "rank = 1\n", - "for _, row in design_df.iterrows():\n", - " for binder in accepted_binders:\n", - " 
target_settings[\"binder_name\"], model = binder.rsplit('_model', 1)\n", - " if target_settings[\"binder_name\"] == row['Design']:\n", - " # rank and copy into ranked folder\n", - " row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}\n", - " final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)\n", - " old_path = os.path.join(design_paths[\"Accepted\"], binder)\n", - " new_path = os.path.join(design_paths[\"Accepted/Ranked\"], f\"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb\")\n", - " shutil.copyfile(old_path, new_path)\n", - "\n", - " rank += 1\n", - " break\n", - "\n", - "# save the final_df to final_csv\n", - "final_df.to_csv(final_csv, index=False)\n", - "\n", - "print(\"Designs ranked and final_designs_stats.csv generated\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "ro9PQBE9zoIw" - }, - "outputs": [], - "source": [ - "#@title Top 20 Designs\n", - "df = pd.read_csv(os.path.join(design_path, 'final_design_stats.csv'))\n", - "df.head(20)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "qOtpEzCbzoX8" - }, - "outputs": [], - "source": [ - "#@title Top Design Display\n", - "import py3Dmol\n", - "import glob\n", - "from IPython.display import HTML\n", - "\n", - "#### pymol top design\n", - "top_design_dir = os.path.join(design_path, 'Accepted', 'Ranked')\n", - "top_design_pdb = glob.glob(os.path.join(top_design_dir, '1_*.pdb'))[0]\n", - "\n", - "# Visualise in PyMOL\n", - "view = py3Dmol.view()\n", - "view.addModel(open(top_design_pdb, 'r').read(),'pdb')\n", - "view.setBackgroundColor('white')\n", - "view.setStyle({'chain':'A'}, {'cartoon': {'color':'#3c5b6f'}})\n", - "view.setStyle({'chain':'B'}, {'cartoon': {'color':'#B76E79'}})\n", - "view.zoomTo()\n", - "view.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": 
"form", - "id": "qX0E849cdpTv" - }, - "outputs": [], - "source": [ - "#@title Display animation\n", - "import glob\n", - "from IPython.display import HTML\n", - "\n", - "#### pymol top design\n", - "top_design_dir = os.path.join(design_path, 'Accepted', 'Ranked')\n", - "top_design_pdb = glob.glob(os.path.join(top_design_dir, '1_*.pdb'))[0]\n", - "\n", - "top_design_name = os.path.basename(top_design_pdb).split('1_', 1)[1].split('_mpnn')[0]\n", - "top_design_animation = os.path.join(design_path, 'Accepted', 'Animation', f\"{top_design_name}.html\")\n", - "\n", - "# Show animation\n", - "HTML(top_design_animation)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "A100", - "machine_shape": "hm", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "MTagjooX0XC1" + }, + "source": [ + "# BindCraft: Protein binder design\n", + "\n", + "\n", + "\n", + "Simple binder design pipeline using AlphaFold2 backpropagation, MPNN, and PyRosetta. Select your target and let the script do the rest of the work and finish once you have enough designs to order!\n", + "\n", + "The designs will be saved on your Google Drive under BindCraft/[design_name]/ and you can continue running the design pipeline if the session times out and it will continue adding new designs." 
+ ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "7fMzl8JiyaXm" + }, + "outputs": [], + "source": [ + "#@title Installation\n", + "%%time\n", + "import os, time, gc, io\n", + "import contextlib\n", + "import json\n", + "from datetime import datetime\n", + "from ipywidgets import HTML, VBox\n", + "from IPython.display import display\n", + "\n", + "if not os.path.isfile(\"bindcraft/params/done.txt\"):\n", + " print(\"Installing required BindCraft components\")\n", + "\n", + " print(\"Pulling BindCraft code from Github\")\n", + " os.makedirs('/content/bindcraft/', exist_ok=True)\n", + " !git clone https://github.com/martinpacesa/BindCraft /content/bindcraft/\n", + " os.system(\"chmod +x /content/bindcraft/functions/dssp\")\n", + " os.system(\"chmod +x /content/bindcraft/functions/DAlphaBall.gcc\")\n", + "\n", + " print(\"Installing ColabDesign\")\n", + " os.system(\"(mkdir bindcraft/params; apt-get install aria2 -qq; \\\n", + " aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar; \\\n", + " tar -xf alphafold_params_2022-12-06.tar -C bindcraft/params; touch bindcraft/params/done.txt )&\")\n", + " os.system(\"pip install git+https://github.com/sokrypton/ColabDesign.git\")\n", + " # for debugging purposes\n", + " os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign\")\n", + "\n", + " print(\"Installing PyRosetta\")\n", + " os.system(\"pip install pyrosetta_installer\")\n", + " with contextlib.redirect_stdout(io.StringIO()):\n", + " import pyrosetta_installer\n", + " pyrosetta_installer.install_pyrosetta(serialization=True)\n", + "\n", + " # download params\n", + " if not os.path.isfile(\"bindcraft/params/done.txt\"):\n", + " print(\"downloading AlphaFold params\")\n", + " while not os.path.isfile(\"bindcraft/params/done.txt\"):\n", + " time.sleep(5)\n", + "\n", + " print(\"BindCraft installation is finished, 
ready to run!\")\n", + "else:\n", + " print(\"BindCraft components already installed, ready to run!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "01IH64-ycCQY" + }, + "outputs": [], + "source": [ + "#@title Mount your Google Drive to save design results\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "print(f\"Google drive mounted at: {currenttime}\")\n", + "\n", + "bindcraft_google_drive = '/content/drive/My Drive/BindCraft/'\n", + "os.makedirs(bindcraft_google_drive, exist_ok=True)\n", + "print(\"BindCraft folder successfully created in your drive!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "RbL-S_t2hicj" + }, + "outputs": [], + "source": [ + "#@title Binder design settings\n", + "# @markdown ---\n", + "# @markdown Enter path where to save your designs. We recommend to save on Google drive so that you can continue generating at any time.\n", + "design_path = \"/content/drive/MyDrive/BindCraft/PDL1/\" # @param {\"type\":\"string\",\"placeholder\":\"/content/drive/MyDrive/BindCraft/PDL1/\"}\n", + "\n", + "# @markdown Enter the name that should be prefixed to your binders (generally target name).\n", + "binder_name = \"PDL1\" # @param {\"type\":\"string\",\"placeholder\":\"PDL1\"}\n", + "\n", + "# @markdown The path to the .pdb structure of your target. Can be an experimental or AlphaFold2 structure. We recommend trimming the structure to as small as needed, as the whole selected chains will be backpropagated through the network and can significantly increase running times.\n", + "starting_pdb = \"/content/bindcraft/example/PDL1.pdb\" # @param {\"type\":\"string\",\"placeholder\":\"/content/bindcraft/example/PDL1.pdb\"}\n", + "\n", + "# @markdown Which chains of your PDB to target? 
Can be one or multiple, in a comma-separated format. Other chains will be ignored during design.\n", + "chains = \"A\" # @param {\"type\":\"string\",\"placeholder\":\"A,C\"}\n", + "\n", + "# @markdown What positions to target in your protein of interest? For example `1,2-10` or chain specific `A1-10,B1-20` or entire chains `A`. If left blank, an appropriate site will be selected by the pipeline.\n", + "target_hotspot_residues = \"\" # @param {\"type\":\"string\",\"placeholder\":\"\"}\n", + "\n", + "# @markdown What is the minimum and maximum size of binders you want to design? Pipeline will randomly sample different sizes between these values.\n", + "lengths = \"60,140\" # @param {\"type\":\"string\",\"placeholder\":\"70,150\"}\n", + "\n", + "# @markdown How many binder designs passing filters do you require?\n", + "number_of_final_designs = 100 # @param {\"type\":\"integer\",\"placeholder\":\"100\"}\n", + "# @markdown ---\n", + "# @markdown Enter path on your Google drive (/content/drive/MyDrive/BindCraft/[binder_name].json) to previous target settings to continue design campaign. 
If left empty, it will use the settings above and generate a new settings json in your design output folder.\n", + "load_previous_target_settings = \"\" # @param {\"type\":\"string\",\"placeholder\":\"\"}\n", + "# @markdown ---\n", + "\n", + "if load_previous_target_settings:\n", + " target_settings_path = load_previous_target_settings\n", + "else:\n", + " lengths = [int(x.strip()) for x in lengths.split(',') if len(lengths.split(',')) == 2]\n", + "\n", + " if len(lengths) != 2:\n", + " raise ValueError(\"Incorrect specification of binder lengths.\")\n", + "\n", + " settings = {\n", + " \"design_path\": design_path,\n", + " \"binder_name\": binder_name,\n", + " \"starting_pdb\": starting_pdb,\n", + " \"chains\": chains,\n", + " \"target_hotspot_residues\": target_hotspot_residues,\n", + " \"lengths\": lengths,\n", + " \"number_of_final_designs\": number_of_final_designs\n", + " }\n", + "\n", + " target_settings_path = os.path.join(design_path, binder_name+\".json\")\n", + " os.makedirs(design_path, exist_ok=True)\n", + "\n", + " with open(target_settings_path, 'w') as f:\n", + " json.dump(settings, f, indent=4)\n", + "\n", + "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "print(f\"Binder design settings updated at: {currenttime}\")\n", + "print(f\"New .json file with target settings has been generated in: {target_settings_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "qcEjqCIlhire" + }, + "outputs": [], + "source": [ + "#@title Advanced settings\n", + "# @markdown ---\n", + "# @markdown Which binder design protocol to run? Default is recommended. \"Beta-sheet\" promotes the design of more beta sheeted proteins, but requires more sampling. \"Peptide\" is optimised for helical peptide binders.\n", + "design_protocol = \"Default\" # @param [\"Default\",\"Beta-sheet\",\"Peptide\"]\n", + "# @markdown What prediction protocol to use?
\"Default\" performs single sequence prediction of the binder. \"HardTarget\" uses initial guess to improve complex prediction for difficult targets, but might introduce some bias.\n", + "prediction_protocol = \"Default\" # @param [\"Default\",\"HardTarget\"]\n", + "# @markdown What interface design method to use? \"AlphaFold2\" is the default, interface is generated by AlphaFold2. \"MPNN\" uses soluble MPNN to optimise the interface.\n", + "interface_protocol = \"AlphaFold2\" # @param [\"AlphaFold2\",\"MPNN\"]\n", + "# @markdown What target template protocol to use? \"Default\" allows for a limited amount of flexibility. \"Masked\" allows for greater target flexibility on both sidechain and backbone level.\n", + "template_protocol = \"Default\" # @param [\"Default\",\"Masked\"]\n", + "# @markdown ---\n", + "\n", + "if design_protocol == \"Default\":\n", + " design_protocol_tag = \"default_4stage_multimer\"\n", + "elif design_protocol == \"Beta-sheet\":\n", + " design_protocol_tag = \"betasheet_4stage_multimer\"\n", + "elif design_protocol == \"Peptide\":\n", + " design_protocol_tag = \"peptide_3stage_multimer\"\n", + "else:\n", + " raise ValueError(f\"Unsupported design protocol\")\n", + "\n", + "if interface_protocol == \"AlphaFold2\":\n", + " interface_protocol_tag = \"\"\n", + "elif interface_protocol == \"MPNN\":\n", + " interface_protocol_tag = \"_mpnn\"\n", + "else:\n", + " raise ValueError(f\"Unsupported interface protocol\")\n", + "\n", + "if template_protocol == \"Default\":\n", + " template_protocol_tag = \"\"\n", + "elif template_protocol == \"Masked\":\n", + " template_protocol_tag = \"_flexible\"\n", + "else:\n", + " raise ValueError(f\"Unsupported template protocol\")\n", + "\n", + "if design_protocol in [\"Peptide\"]:\n", + " prediction_protocol_tag = \"\"\n", + "else:\n", + " if prediction_protocol == \"Default\":\n", + " prediction_protocol_tag = \"\"\n", + " elif prediction_protocol == \"HardTarget\":\n", + " prediction_protocol_tag = 
\"_hardtarget\"\n", + " else:\n", + " raise ValueError(f\"Unsupported prediction protocol\")\n", + "\n", + "advanced_settings_path = \"/content/bindcraft/settings_advanced/\" + design_protocol_tag + interface_protocol_tag + template_protocol_tag + prediction_protocol_tag + \".json\"\n", + "\n", + "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "print(f\"Advanced design settings updated at: {currenttime}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "r-OpCVe4hi5Q" + }, + "outputs": [], + "source": [ + "#@title Filters\n", + "# @markdown ---\n", + "# @markdown Which filters for designs to use? \"Default\" are recommended, \"Peptide\" are for the design of peptide binders, \"Relaxed\" are more permissive but may result in fewer experimental successes, \"Peptide_Relaxed\" are more permissive filters for non-helical peptides, \"None\" is for benchmarking.\n", + "filter_option = \"Default\" # @param [\"Default\", \"Peptide\", \"Relaxed\", \"Peptide_Relaxed\", \"None\"]\n", + "# @markdown ---\n", + "\n", + "if filter_option == \"Default\":\n", + " filter_settings_path = \"/content/bindcraft/settings_filters/default_filters.json\"\n", + "elif filter_option == \"Peptide\":\n", + " filter_settings_path = \"/content/bindcraft/settings_filters/peptide_filters.json\"\n", + "elif filter_option == \"Relaxed\":\n", + " filter_settings_path = \"/content/bindcraft/settings_filters/relaxed_filters.json\"\n", + "elif filter_option == \"Peptide_Relaxed\":\n", + " filter_settings_path = \"/content/bindcraft/settings_filters/peptide_relaxed_filters.json\"\n", + "elif filter_option == \"None\":\n", + " filter_settings_path = \"/content/bindcraft/settings_filters/no_filters.json\"\n", + "else:\n", + " raise ValueError(f\"Unsupported filter type\")\n", + "\n", + "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "print(f\"Filter settings updated at: {currenttime}\")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "id": "BR3gtmcChtvX" + }, + "source": [ + "# Everything is set, BindCraft is ready to run!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "LgRFO3EKAnM5" + }, + "outputs": [], + "source": [ + "# @title Import functions and settings\n", + "from bindcraft.functions import *\n", + "\n", + "args = {\"settings\":target_settings_path,\n", + " \"filters\":filter_settings_path,\n", + " \"advanced\":advanced_settings_path}\n", + "\n", + "# Check if JAX-capable GPU is available, otherwise exit\n", + "check_jax_gpu()\n", + "\n", + "# perform checks of input setting files\n", + "settings_path, filters_path, advanced_path = (args[\"settings\"], args[\"filters\"], args[\"advanced\"])\n", + "\n", + "### load settings from JSON\n", + "target_settings, advanced_settings, filters = load_json_settings(settings_path, filters_path, advanced_path)\n", + "\n", + "settings_file = os.path.basename(settings_path).split('.')[0]\n", + "filters_file = os.path.basename(filters_path).split('.')[0]\n", + "advanced_file = os.path.basename(advanced_path).split('.')[0]\n", + "\n", + "### load AF2 model settings\n", + "design_models, prediction_models, multimer_validation = load_af2_models(advanced_settings[\"use_multimer_design\"])\n", + "\n", + "### perform checks on advanced_settings\n", + "bindcraft_folder = \"colab\"\n", + "advanced_settings = perform_advanced_settings_check(advanced_settings, bindcraft_folder)\n", + "\n", + "### generate directories, design path names can be found within the function\n", + "design_paths = generate_directories(target_settings[\"design_path\"])\n", + "\n", + "### generate dataframes\n", + "trajectory_labels, design_labels, final_labels = generate_dataframe_labels()\n", + "\n", + "trajectory_csv = os.path.join(target_settings[\"design_path\"], 'trajectory_stats.csv')\n", + "mpnn_csv = os.path.join(target_settings[\"design_path\"], 'mpnn_design_stats.csv')\n", + 
"final_csv = os.path.join(target_settings[\"design_path\"], 'final_design_stats.csv')\n", + "failure_csv = os.path.join(target_settings[\"design_path\"], 'failure_csv.csv')\n", + "\n", + "create_dataframe(trajectory_csv, trajectory_labels)\n", + "create_dataframe(mpnn_csv, design_labels)\n", + "create_dataframe(final_csv, final_labels)\n", + "generate_filter_pass_csv(failure_csv, args[\"filters\"])\n", + "\n", + "currenttime = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", + "print(f\"Loaded design functions and settings at: {currenttime}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "1sOAn_xyEZKo" + }, + "outputs": [], + "source": [ + "#@title Initialise PyRosetta\n", + "\n", + "####################################\n", + "####################################\n", + "####################################\n", + "### initialise PyRosetta\n", + "pr.init(f'-ignore_unrecognized_res -ignore_zero_occupancy -mute all -holes:dalphaball {advanced_settings[\"dalphaball_path\"]} -corrections::beta_nov16 true -relax:default_repeats 1')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ZH2hVVrpzn-o" + }, + "outputs": [], + "source": [ + "#@title Run BindCraft!\n", + "####################################\n", + "###################### BindCraft Run\n", + "####################################\n", + "# Colab-specific: live displays\n", + "num_sampled_trajectories = len(pd.read_csv(trajectory_csv))\n", + "num_accepted_designs = len(pd.read_csv(final_csv))\n", + "sampled_trajectories_label = HTML(value=f\"

Sampled Trajectories: {num_sampled_trajectories}

\")\n", + "accepted_designs_label = HTML(value=f\"

Accepted Designs: {num_accepted_designs}

\")\n", + "display(VBox([sampled_trajectories_label, accepted_designs_label]))\n", + "\n", + "# initialise counters\n", + "script_start_time = time.time()\n", + "trajectory_n = 1\n", + "accepted_designs = 0\n", + "trajectory_runtime = TrajectoryDesignRuntime(advanced_settings)\n", + "\n", + "### start design loop\n", + "while True:\n", + " ### check if we have the target number of binders\n", + " final_designs_reached = check_accepted_designs(design_paths, mpnn_csv, final_labels, final_csv, advanced_settings, target_settings, design_labels)\n", + "\n", + " if final_designs_reached:\n", + " # stop design loop execution\n", + " break\n", + "\n", + " ### check if we reached maximum allowed trajectories\n", + " max_trajectories_reached = check_n_trajectories(design_paths, advanced_settings)\n", + "\n", + " if max_trajectories_reached:\n", + " break\n", + "\n", + " ### Initialise design\n", + " # measure time to generate design\n", + " trajectory_start_time = time.time()\n", + "\n", + " # generate random seed to vary designs\n", + " seed = int(np.random.randint(0, high=999999, size=1, dtype=int)[0])\n", + "\n", + " # sample binder design length randomly from defined distribution\n", + " samples = np.arange(min(target_settings[\"lengths\"]), max(target_settings[\"lengths\"]) + 1)\n", + " length = np.random.choice(samples)\n", + "\n", + " # load desired helicity value to sample different secondary structure contents\n", + " helicity_value = load_helicity(advanced_settings)\n", + "\n", + " # generate design name and check if same trajectory was already run\n", + " design_name = target_settings[\"binder_name\"] + \"_l\" + str(length) + \"_s\"+ str(seed)\n", + " trajectory_dirs = [\"Trajectory\", \"Trajectory/Relaxed\", \"Trajectory/LowConfidence\", \"Trajectory/Clashing\"]\n", + " trajectory_exists = any(os.path.exists(os.path.join(design_paths[trajectory_dir], design_name + \".pdb\")) for trajectory_dir in trajectory_dirs)\n", + "\n", + " if not trajectory_exists:\n", + " 
print(\"Starting trajectory: \"+design_name)\n", + "\n", + " ### Begin binder hallucination\n", + " trajectory = binder_hallucination(design_name, target_settings[\"starting_pdb\"], target_settings[\"chains\"],\n", + " target_settings[\"target_hotspot_residues\"], length, seed, helicity_value,\n", + " design_models, advanced_settings, design_paths, failure_csv,\n", + " runtime=trajectory_runtime)\n", + " trajectory_metrics = copy_dict(trajectory._tmp[\"best\"][\"aux\"][\"log\"]) # contains plddt, ptm, i_ptm, pae, i_pae\n", + " trajectory_pdb = os.path.join(design_paths[\"Trajectory\"], design_name + \".pdb\")\n", + "\n", + " # round the metrics to two decimal places\n", + " trajectory_metrics = {k: round(v, 2) if isinstance(v, float) else v for k, v in trajectory_metrics.items()}\n", + "\n", + " # time trajectory\n", + " trajectory_time = time.time() - trajectory_start_time\n", + " trajectory_time_text = f\"{'%d hours, %d minutes, %d seconds' % (int(trajectory_time // 3600), int((trajectory_time % 3600) // 60), int(trajectory_time % 60))}\"\n", + " print(\"Starting trajectory took: \"+trajectory_time_text)\n", + " print(\"\")\n", + "\n", + " # Proceed if there is no trajectory termination signal\n", + " if trajectory.aux[\"log\"][\"terminate\"] == \"\":\n", + " # Relax binder to calculate statistics\n", + " trajectory_relaxed = os.path.join(design_paths[\"Trajectory/Relaxed\"], design_name + \".pdb\")\n", + " pr_relax(trajectory_pdb, trajectory_relaxed)\n", + "\n", + " # define binder chain, placeholder in case multi-chain parsing in ColabDesign gets changed\n", + " binder_chain = \"B\"\n", + "\n", + " # Calculate clashes before and after relaxation\n", + " num_clashes_trajectory = calculate_clash_score(trajectory_pdb)\n", + " num_clashes_relaxed = calculate_clash_score(trajectory_relaxed)\n", + "\n", + " # secondary structure content of starting trajectory binder and interface\n", + " trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_alpha_interface, 
trajectory_beta_interface, trajectory_loops_interface, trajectory_i_plddt, trajectory_ss_plddt = calc_ss_percentage(trajectory_pdb, advanced_settings, binder_chain)\n", + "\n", + " # analyze interface scores for relaxed af2 trajectory\n", + " trajectory_interface_scores, trajectory_interface_AA, trajectory_interface_residues = score_interface(trajectory_relaxed, binder_chain)\n", + "\n", + " # starting binder sequence\n", + " trajectory_sequence = trajectory.get_seq(get_best=True)[0]\n", + "\n", + " # analyze sequence\n", + " traj_seq_notes = validate_design_sequence(trajectory_sequence, num_clashes_relaxed, advanced_settings)\n", + "\n", + " # target structure RMSD compared to input PDB\n", + " trajectory_target_rmsd = unaligned_rmsd(target_settings[\"starting_pdb\"], trajectory_pdb, target_settings[\"chains\"], 'A')\n", + "\n", + " # save trajectory statistics into CSV\n", + " trajectory_data = [design_name, advanced_settings[\"design_algorithm\"], length, seed, helicity_value, target_settings[\"target_hotspot_residues\"], trajectory_sequence, trajectory_interface_residues,\n", + " trajectory_metrics['plddt'], trajectory_metrics['ptm'], trajectory_metrics['i_ptm'], trajectory_metrics['pae'], trajectory_metrics['i_pae'],\n", + " trajectory_i_plddt, trajectory_ss_plddt, num_clashes_trajectory, num_clashes_relaxed, trajectory_interface_scores['binder_score'],\n", + " trajectory_interface_scores['surface_hydrophobicity'], trajectory_interface_scores['interface_sc'], trajectory_interface_scores['interface_packstat'],\n", + " trajectory_interface_scores['interface_dG'], trajectory_interface_scores['interface_dSASA'], trajectory_interface_scores['interface_dG_SASA_ratio'],\n", + " trajectory_interface_scores['interface_fraction'], trajectory_interface_scores['interface_hydrophobicity'], trajectory_interface_scores['interface_nres'], trajectory_interface_scores['interface_interface_hbonds'],\n", + " trajectory_interface_scores['interface_hbond_percentage'], 
trajectory_interface_scores['interface_delta_unsat_hbonds'], trajectory_interface_scores['interface_delta_unsat_hbonds_percentage'],\n", + " trajectory_alpha_interface, trajectory_beta_interface, trajectory_loops_interface, trajectory_alpha, trajectory_beta, trajectory_loops, trajectory_interface_AA, trajectory_target_rmsd,\n", + " trajectory_time_text, traj_seq_notes, settings_file, filters_file, advanced_file]\n", + " insert_data(trajectory_csv, trajectory_data)\n", + "\n", + " if advanced_settings[\"enable_mpnn\"]:\n", + " # initialise MPNN counters\n", + " mpnn_n = 1\n", + " accepted_mpnn = 0\n", + " mpnn_dict = {}\n", + " design_start_time = time.time()\n", + "\n", + " ### MPNN redesign of starting binder\n", + " mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)\n", + " trajectory_runtime.invalidate()\n", + " existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)\n", + "\n", + " # create set of MPNN sequences with allowed amino acid composition\n", + " restricted_AAs = set(aa.strip().upper() for aa in advanced_settings[\"omit_AAs\"].split(',')) if advanced_settings[\"force_reject_AA\"] else set()\n", + "\n", + " mpnn_sequences = sorted({\n", + " mpnn_trajectories['seq'][n][-length:]: {\n", + " 'seq': mpnn_trajectories['seq'][n][-length:],\n", + " 'score': mpnn_trajectories['score'][n],\n", + " 'seqid': mpnn_trajectories['seqid'][n]\n", + " } for n in range(advanced_settings[\"num_seqs\"])\n", + " if (not restricted_AAs or not any(aa in mpnn_trajectories['seq'][n][-length:].upper() for aa in restricted_AAs))\n", + " and mpnn_trajectories['seq'][n][-length:] not in existing_mpnn_sequences\n", + " }.values(), key=lambda x: x['score'])\n", + "\n", + " del existing_mpnn_sequences\n", + "\n", + " # check whether any sequences are left after amino acid rejection and duplication check, and if yes proceed with prediction\n", + " if mpnn_sequences:\n", + " # add 
optimisation for increasing recycles if trajectory is beta sheeted\n", + " if advanced_settings[\"optimise_beta\"] and float(trajectory_beta) > 15:\n", + " advanced_settings[\"num_recycles_validation\"] = advanced_settings[\"optimise_beta_recycles_valid\"]\n", + "\n", + " ### Compile prediction models once for faster prediction of MPNN sequences\n", + " clear_mem()\n", + " # compile complex prediction model\n", + " complex_prediction_model = mk_afdesign_model(protocol=\"binder\", num_recycles=advanced_settings[\"num_recycles_validation\"], data_dir=advanced_settings[\"af_params_dir\"],\n", + " use_multimer=multimer_validation)\n", + " complex_prediction_model.prep_inputs(pdb_filename=target_settings[\"starting_pdb\"], chain=target_settings[\"chains\"], binder_len=length, rm_target_seq=advanced_settings[\"rm_template_seq_predict\"],\n", + " rm_target_sc=advanced_settings[\"rm_template_sc_predict\"])\n", + "\n", + " # compile binder monomer prediction model\n", + " binder_prediction_model = mk_afdesign_model(protocol=\"hallucination\", use_templates=False, initial_guess=False,\n", + " use_initial_atom_pos=False, num_recycles=advanced_settings[\"num_recycles_validation\"],\n", + " data_dir=advanced_settings[\"af_params_dir\"], use_multimer=multimer_validation)\n", + " binder_prediction_model.prep_inputs(length=length)\n", + "\n", + " # iterate over designed sequences\n", + " for mpnn_sequence in mpnn_sequences:\n", + " mpnn_time = time.time()\n", + "\n", + " # generate mpnn design name numbering\n", + " mpnn_design_name = design_name + \"_mpnn\" + str(mpnn_n)\n", + " mpnn_score = round(mpnn_sequence['score'],2)\n", + " mpnn_seqid = round(mpnn_sequence['seqid'],2)\n", + "\n", + " # add design to dictionary\n", + " mpnn_dict[mpnn_design_name] = {'seq': mpnn_sequence['seq'], 'score': mpnn_score, 'seqid': mpnn_seqid}\n", + "\n", + " # save fasta sequence\n", + " if advanced_settings[\"save_mpnn_fasta\"] is True:\n", + " save_fasta(mpnn_design_name, mpnn_sequence['seq'], 
design_paths)\n", + "\n", + " ### Predict mpnn redesigned binder complex using masked templates\n", + " mpnn_complex_statistics, pass_af2_filters = predict_binder_complex(complex_prediction_model,\n", + " mpnn_sequence['seq'], mpnn_design_name,\n", + " target_settings[\"starting_pdb\"], target_settings[\"chains\"],\n", + " length, trajectory_pdb, prediction_models, advanced_settings,\n", + " filters, design_paths, failure_csv)\n", + "\n", + " # if AF2 filters are not passed then skip the scoring\n", + " if not pass_af2_filters:\n", + " print(f\"Base AF2 filters not passed for {mpnn_design_name}, skipping interface scoring\")\n", + " mpnn_n += 1\n", + " continue\n", + "\n", + " # calculate statistics for each model individually\n", + " for model_num in prediction_models:\n", + " mpnn_design_pdb = os.path.join(design_paths[\"MPNN\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", + " mpnn_design_relaxed = os.path.join(design_paths[\"MPNN/Relaxed\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", + "\n", + " if os.path.exists(mpnn_design_pdb):\n", + " # Calculate clashes before and after relaxation\n", + " num_clashes_mpnn = calculate_clash_score(mpnn_design_pdb)\n", + " num_clashes_mpnn_relaxed = calculate_clash_score(mpnn_design_relaxed)\n", + "\n", + " # analyze interface scores for relaxed af2 trajectory\n", + " mpnn_interface_scores, mpnn_interface_AA, mpnn_interface_residues = score_interface(mpnn_design_relaxed, binder_chain)\n", + "\n", + " # secondary structure content of starting trajectory binder\n", + " mpnn_alpha, mpnn_beta, mpnn_loops, mpnn_alpha_interface, mpnn_beta_interface, mpnn_loops_interface, mpnn_i_plddt, mpnn_ss_plddt = calc_ss_percentage(mpnn_design_pdb, advanced_settings, binder_chain)\n", + "\n", + " # unaligned RMSD calculate to determine if binder is in the designed binding site\n", + " rmsd_site = unaligned_rmsd(trajectory_pdb, mpnn_design_pdb, binder_chain, binder_chain)\n", + "\n", + " # calculate RMSD of target compared to 
input PDB\n", + " target_rmsd = target_pdb_rmsd(mpnn_design_pdb, target_settings[\"starting_pdb\"], target_settings[\"chains\"])\n", + "\n", + " # add the additional statistics to the mpnn_complex_statistics dictionary\n", + " mpnn_complex_statistics[model_num+1].update({\n", + " 'i_pLDDT': mpnn_i_plddt,\n", + " 'ss_pLDDT': mpnn_ss_plddt,\n", + " 'Unrelaxed_Clashes': num_clashes_mpnn,\n", + " 'Relaxed_Clashes': num_clashes_mpnn_relaxed,\n", + " 'Binder_Energy_Score': mpnn_interface_scores['binder_score'],\n", + " 'Surface_Hydrophobicity': mpnn_interface_scores['surface_hydrophobicity'],\n", + " 'ShapeComplementarity': mpnn_interface_scores['interface_sc'],\n", + " 'PackStat': mpnn_interface_scores['interface_packstat'],\n", + " 'dG': mpnn_interface_scores['interface_dG'],\n", + " 'dSASA': mpnn_interface_scores['interface_dSASA'],\n", + " 'dG/dSASA': mpnn_interface_scores['interface_dG_SASA_ratio'],\n", + " 'Interface_SASA_%': mpnn_interface_scores['interface_fraction'],\n", + " 'Interface_Hydrophobicity': mpnn_interface_scores['interface_hydrophobicity'],\n", + " 'n_InterfaceResidues': mpnn_interface_scores['interface_nres'],\n", + " 'n_InterfaceHbonds': mpnn_interface_scores['interface_interface_hbonds'],\n", + " 'InterfaceHbondsPercentage': mpnn_interface_scores['interface_hbond_percentage'],\n", + " 'n_InterfaceUnsatHbonds': mpnn_interface_scores['interface_delta_unsat_hbonds'],\n", + " 'InterfaceUnsatHbondsPercentage': mpnn_interface_scores['interface_delta_unsat_hbonds_percentage'],\n", + " 'InterfaceAAs': mpnn_interface_AA,\n", + " 'Interface_Helix%': mpnn_alpha_interface,\n", + " 'Interface_BetaSheet%': mpnn_beta_interface,\n", + " 'Interface_Loop%': mpnn_loops_interface,\n", + " 'Binder_Helix%': mpnn_alpha,\n", + " 'Binder_BetaSheet%': mpnn_beta,\n", + " 'Binder_Loop%': mpnn_loops,\n", + " 'Hotspot_RMSD': rmsd_site,\n", + " 'Target_RMSD': target_rmsd\n", + " })\n", + "\n", + " # save space by removing unrelaxed predicted mpnn complex pdb?\n", + " if 
advanced_settings[\"remove_unrelaxed_complex\"]:\n", + " os.remove(mpnn_design_pdb)\n", + "\n", + " # calculate complex averages\n", + " mpnn_complex_averages = calculate_averages(mpnn_complex_statistics, handle_aa=True)\n", + "\n", + " ### Predict binder alone in single sequence mode\n", + " binder_statistics = predict_binder_alone(binder_prediction_model, mpnn_sequence['seq'], mpnn_design_name, length,\n", + " trajectory_pdb, binder_chain, prediction_models, advanced_settings, design_paths)\n", + "\n", + " # extract RMSDs of binder to the original trajectory\n", + " for model_num in prediction_models:\n", + " mpnn_binder_pdb = os.path.join(design_paths[\"MPNN/Binder\"], f\"{mpnn_design_name}_model{model_num+1}.pdb\")\n", + "\n", + " if os.path.exists(mpnn_binder_pdb):\n", + " rmsd_binder = unaligned_rmsd(trajectory_pdb, mpnn_binder_pdb, binder_chain, \"A\")\n", + "\n", + " # append to statistics\n", + " binder_statistics[model_num+1].update({\n", + " 'Binder_RMSD': rmsd_binder\n", + " })\n", + "\n", + " # save space by removing binder monomer models?\n", + " if advanced_settings[\"remove_binder_monomer\"]:\n", + " os.remove(mpnn_binder_pdb)\n", + "\n", + " # calculate binder averages\n", + " binder_averages = calculate_averages(binder_statistics)\n", + "\n", + " # analyze sequence to make sure there are no cysteines and it contains residues that absorb UV for detection\n", + " seq_notes = validate_design_sequence(mpnn_sequence['seq'], mpnn_complex_averages.get('Relaxed_Clashes', None), advanced_settings)\n", + "\n", + " # measure time to generate design\n", + " mpnn_end_time = time.time() - mpnn_time\n", + " elapsed_mpnn_text = f\"{'%d hours, %d minutes, %d seconds' % (int(mpnn_end_time // 3600), int((mpnn_end_time % 3600) // 60), int(mpnn_end_time % 60))}\"\n", + "\n", + "\n", + " # Insert statistics about MPNN design into CSV, will return None if corresponding model does not exist\n", + " model_numbers = range(1, 6)\n", + " statistics_labels = ['pLDDT', 'pTM',
'i_pTM', 'pAE', 'i_pAE', 'i_pLDDT', 'ss_pLDDT', 'Unrelaxed_Clashes', 'Relaxed_Clashes', 'Binder_Energy_Score', 'Surface_Hydrophobicity',\n", + " 'ShapeComplementarity', 'PackStat', 'dG', 'dSASA', 'dG/dSASA', 'Interface_SASA_%', 'Interface_Hydrophobicity', 'n_InterfaceResidues', 'n_InterfaceHbonds', 'InterfaceHbondsPercentage',\n", + " 'n_InterfaceUnsatHbonds', 'InterfaceUnsatHbondsPercentage', 'Interface_Helix%', 'Interface_BetaSheet%', 'Interface_Loop%', 'Binder_Helix%',\n", + " 'Binder_BetaSheet%', 'Binder_Loop%', 'InterfaceAAs', 'Hotspot_RMSD', 'Target_RMSD']\n", + "\n", + " # Initialize mpnn_data with the non-statistical data\n", + " mpnn_data = [mpnn_design_name, advanced_settings[\"design_algorithm\"], length, seed, helicity_value, target_settings[\"target_hotspot_residues\"], mpnn_sequence['seq'], mpnn_interface_residues, mpnn_score, mpnn_seqid]\n", + "\n", + " # Add the statistical data for mpnn_complex\n", + " for label in statistics_labels:\n", + " mpnn_data.append(mpnn_complex_averages.get(label, None))\n", + " for model in model_numbers:\n", + " mpnn_data.append(mpnn_complex_statistics.get(model, {}).get(label, None))\n", + "\n", + " # Add the statistical data for binder\n", + " for label in ['pLDDT', 'pTM', 'pAE', 'Binder_RMSD']: # These are the labels for binder alone\n", + " mpnn_data.append(binder_averages.get(label, None))\n", + " for model in model_numbers:\n", + " mpnn_data.append(binder_statistics.get(model, {}).get(label, None))\n", + "\n", + " # Add the remaining non-statistical data\n", + " mpnn_data.extend([elapsed_mpnn_text, seq_notes, settings_file, filters_file, advanced_file])\n", + "\n", + " # insert data into csv\n", + " insert_data(mpnn_csv, mpnn_data)\n", + "\n", + " # find best model number by pLDDT\n", + " plddt_values = {i: mpnn_data[i] for i in range(11, 15) if mpnn_data[i] is not None}\n", + "\n", + " # Find the key with the highest value\n", + " highest_plddt_key = int(max(plddt_values, key=plddt_values.get))\n", + "\n", + " # 
Output the number part of the key\n", + " best_model_number = highest_plddt_key - 10\n", + " best_model_pdb = os.path.join(design_paths[\"MPNN/Relaxed\"], f\"{mpnn_design_name}_model{best_model_number}.pdb\")\n", + "\n", + " # run design data against filter thresholds\n", + " filter_conditions = check_filters(mpnn_data, design_labels, filters)\n", + " if filter_conditions == True:\n", + " print(mpnn_design_name+\" passed all filters\")\n", + " accepted_mpnn += 1\n", + " accepted_designs += 1\n", + "\n", + " # copy designs to accepted folder\n", + " shutil.copy(best_model_pdb, design_paths[\"Accepted\"])\n", + "\n", + " # insert data into final csv\n", + " final_data = [''] + mpnn_data\n", + " insert_data(final_csv, final_data)\n", + "\n", + " # copy animation from accepted trajectory\n", + " if advanced_settings[\"save_design_animations\"]:\n", + " accepted_animation = os.path.join(design_paths[\"Accepted/Animation\"], f\"{design_name}.html\")\n", + " if not os.path.exists(accepted_animation):\n", + " shutil.copy(os.path.join(design_paths[\"Trajectory/Animation\"], f\"{design_name}.html\"), accepted_animation)\n", + "\n", + " # copy plots of accepted trajectory\n", + " plot_files = os.listdir(design_paths[\"Trajectory/Plots\"])\n", + " plots_to_copy = [f for f in plot_files if f.startswith(design_name) and f.endswith('.png')]\n", + " for accepted_plot in plots_to_copy:\n", + " source_plot = os.path.join(design_paths[\"Trajectory/Plots\"], accepted_plot)\n", + " target_plot = os.path.join(design_paths[\"Accepted/Plots\"], accepted_plot)\n", + " if not os.path.exists(target_plot):\n", + " shutil.copy(source_plot, target_plot)\n", + "\n", + " else:\n", + " print(f\"Unmet filter conditions for {mpnn_design_name}\")\n", + " failure_df = pd.read_csv(failure_csv)\n", + " special_prefixes = ('Average_', '1_', '2_', '3_', '4_', '5_')\n", + " incremented_columns = set()\n", + "\n", + " for column in filter_conditions:\n", + " base_column = column\n", + " for prefix in 
special_prefixes:\n", + " if column.startswith(prefix):\n", + " base_column = column.split('_', 1)[1]\n", + "\n", + " if base_column not in incremented_columns:\n", + " failure_df[base_column] = failure_df[base_column] + 1\n", + " incremented_columns.add(base_column)\n", + "\n", + " failure_df.to_csv(failure_csv, index=False)\n", + " shutil.copy(best_model_pdb, design_paths[\"Rejected\"])\n", + "\n", + " # increase MPNN design number\n", + " mpnn_n += 1\n", + "\n", + " # if enough mpnn sequences of the same trajectory pass filters then stop\n", + " if accepted_mpnn >= advanced_settings[\"max_mpnn_sequences\"]:\n", + " break\n", + "\n", + " if accepted_mpnn >= 1:\n", + " print(\"Found \"+str(accepted_mpnn)+\" MPNN designs passing filters\")\n", + " else:\n", + " print(\"No accepted MPNN designs found for this trajectory.\")\n", + "\n", + " else:\n", + " print('Duplicate MPNN designs sampled with different trajectory, skipping current trajectory optimisation')\n", + "\n", + " # save space by removing unrelaxed design trajectory PDB\n", + " if advanced_settings[\"remove_unrelaxed_trajectory\"]:\n", + " os.remove(trajectory_pdb)\n", + "\n", + " # measure time it took to generate designs for one trajectory\n", + " design_time = time.time() - design_start_time\n", + " design_time_text = f\"{'%d hours, %d minutes, %d seconds' % (int(design_time // 3600), int((design_time % 3600) // 60), int(design_time % 60))}\"\n", + " print(\"Design and validation of trajectory \"+design_name+\" took: \"+design_time_text)\n", + "\n", + " # analyse the rejection rate of trajectories to see if we need to readjust the design weights\n", + " if trajectory_n >= advanced_settings[\"start_monitoring\"] and advanced_settings[\"enable_rejection_check\"]:\n", + " acceptance = accepted_designs / trajectory_n\n", + " if not acceptance >= advanced_settings[\"acceptance_rate\"]:\n", + " print(\"The ratio of successful designs is lower than defined acceptance rate! 
Consider changing your design settings!\")\n", + " print(\"Script execution stopping...\")\n", + " break\n", + "\n", + " # increase trajectory number\n", + " trajectory_n += 1\n", + "\n", + " # Colab-specific: update counters\n", + " num_sampled_trajectories = len(pd.read_csv(trajectory_csv))\n", + " num_accepted_designs = len(pd.read_csv(final_csv))\n", + " sampled_trajectories_label.value = f\"Sampled trajectories: {num_sampled_trajectories}\"\n", + " accepted_designs_label.value = f\"Accepted designs: {num_accepted_designs}\"\n", + "\n", + "### Script finished\n", + "elapsed_time = time.time() - script_start_time\n", + "elapsed_text = f\"{'%d hours, %d minutes, %d seconds' % (int(elapsed_time // 3600), int((elapsed_time % 3600) // 60), int(elapsed_time % 60))}\"\n", + "print(\"Finished all designs. Script execution for \"+str(trajectory_n)+\" trajectories took: \"+elapsed_text)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jdmYnBypaUHR" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "U5qKDGGabhZo" + }, + "outputs": [], + "source": [ + "#@title Consolidate & Rank Designs\n", + "#@markdown ---\n", + "accepted_binders = [f for f in os.listdir(design_paths[\"Accepted\"]) if f.endswith('.pdb')]\n", + "\n", + "for f in os.listdir(design_paths[\"Accepted/Ranked\"]):\n", + " os.remove(os.path.join(design_paths[\"Accepted/Ranked\"], f))\n", + "\n", + "# load dataframe of designed binders\n", + "design_df = pd.read_csv(mpnn_csv)\n", + "design_df = design_df.sort_values('Average_i_pTM', ascending=False)\n", + "\n", + "# create final csv dataframe to copy matched rows, initialize with the column labels\n", + "final_df = pd.DataFrame(columns=final_labels)\n", + "\n", + "# check the ranking of the designs and copy them with new ranked IDs to the folder\n", + "rank = 1\n", + "for _, row in design_df.iterrows():\n", + " for binder in accepted_binders:\n", + " 
target_settings[\"binder_name\"], model = binder.rsplit('_model', 1)\n", + " if target_settings[\"binder_name\"] == row['Design']:\n", + " # rank and copy into ranked folder\n", + " row_data = {'Rank': rank, **{label: row[label] for label in design_labels}}\n", + " final_df = pd.concat([final_df, pd.DataFrame([row_data])], ignore_index=True)\n", + " old_path = os.path.join(design_paths[\"Accepted\"], binder)\n", + " new_path = os.path.join(design_paths[\"Accepted/Ranked\"], f\"{rank}_{target_settings['binder_name']}_model{model.rsplit('.', 1)[0]}.pdb\")\n", + " shutil.copyfile(old_path, new_path)\n", + "\n", + " rank += 1\n", + " break\n", + "\n", + "# save the final_df to final_csv\n", + "final_df.to_csv(final_csv, index=False)\n", + "\n", + "print(\"Designs ranked and final_designs_stats.csv generated\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ro9PQBE9zoIw" + }, + "outputs": [], + "source": [ + "#@title Top 20 Designs\n", + "df = pd.read_csv(os.path.join(design_path, 'final_design_stats.csv'))\n", + "df.head(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "qOtpEzCbzoX8" + }, + "outputs": [], + "source": [ + "#@title Top Design Display\n", + "import py3Dmol\n", + "import glob\n", + "from IPython.display import HTML\n", + "\n", + "#### pymol top design\n", + "top_design_dir = os.path.join(design_path, 'Accepted', 'Ranked')\n", + "top_design_pdb = glob.glob(os.path.join(top_design_dir, '1_*.pdb'))[0]\n", + "\n", + "# Visualise in PyMOL\n", + "view = py3Dmol.view()\n", + "view.addModel(open(top_design_pdb, 'r').read(),'pdb')\n", + "view.setBackgroundColor('white')\n", + "view.setStyle({'chain':'A'}, {'cartoon': {'color':'#3c5b6f'}})\n", + "view.setStyle({'chain':'B'}, {'cartoon': {'color':'#B76E79'}})\n", + "view.zoomTo()\n", + "view.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": 
"form", + "id": "qX0E849cdpTv" + }, + "outputs": [], + "source": [ + "#@title Display animation\n", + "import glob\n", + "from IPython.display import HTML\n", + "\n", + "#### pymol top design\n", + "top_design_dir = os.path.join(design_path, 'Accepted', 'Ranked')\n", + "top_design_pdb = glob.glob(os.path.join(top_design_dir, '1_*.pdb'))[0]\n", + "\n", + "top_design_name = os.path.basename(top_design_pdb).split('1_', 1)[1].split('_mpnn')[0]\n", + "top_design_animation = os.path.join(design_path, 'Accepted', 'Animation', f\"{top_design_name}.html\")\n", + "\n", + "# Show animation\n", + "HTML(top_design_animation)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } \ No newline at end of file