Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
*.pyc
5 changes: 4 additions & 1 deletion bindcraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
script_start_time = time.time()
trajectory_n = 1
accepted_designs = 0
trajectory_runtime = TrajectoryDesignRuntime(advanced_settings)

### start design loop
while True:
Expand Down Expand Up @@ -108,7 +109,8 @@
### Begin binder hallucination
trajectory = binder_hallucination(design_name, target_settings["starting_pdb"], target_settings["chains"],
target_settings["target_hotspot_residues"], length, seed, helicity_value,
design_models, advanced_settings, design_paths, failure_csv)
design_models, advanced_settings, design_paths, failure_csv,
runtime=trajectory_runtime)
trajectory_metrics = copy_dict(trajectory._tmp["best"]["aux"]["log"]) # contains plddt, ptm, i_ptm, pae, i_pae
trajectory_pdb = os.path.join(design_paths["Trajectory"], design_name + ".pdb")

Expand Down Expand Up @@ -174,6 +176,7 @@

### MPNN redesign of starting binder
mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
trajectory_runtime.invalidate()
Comment on lines 178 to +179
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mpnn_gen_sequence() calls clear_mem() internally, which can free the JAX buffers backing trajectory_runtime.af_model while trajectory_runtime still holds a reference to it. To avoid a transient invalid cached state (and make failure modes safer), consider calling trajectory_runtime.invalidate() immediately before invoking mpnn_gen_sequence() (or wrap the call in a try/finally that invalidates even if MPNN generation errors).

Suggested change
mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
trajectory_runtime.invalidate()
trajectory_runtime.invalidate()
try:
mpnn_trajectories = mpnn_gen_sequence(trajectory_pdb, binder_chain, trajectory_interface_residues, advanced_settings)
finally:
trajectory_runtime.invalidate()

Copilot uses AI. Check for mistakes.
existing_mpnn_sequences = set(pd.read_csv(mpnn_csv, usecols=['Sequence'])['Sequence'].values)

# create set of MPNN sequences with allowed amino acid composition
Expand Down
223 changes: 177 additions & 46 deletions functions/colabdesign_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,53 +17,182 @@
from .pyrosetta_utils import pr_relax, align_pdbs
from .generic_utils import update_failures

# hallucinate a binder
def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residues, length, seed, helicity_value, design_models, advanced_settings, design_paths, failure_csv):
model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb")
_DEFAULT_TRAJECTORY_RUNTIME = None


class TrajectoryDesignRuntime:
    """Reusable AF2 trajectory runtime to avoid per-trajectory model rebuilds.

    Caches one ColabDesign AF2 "binder" design model across trajectories and
    re-prepares its inputs/loss weights per trajectory instead of rebuilding
    the model (and re-allocating JAX buffers) each time.
    """

    def __init__(self, advanced_settings):
        # Settings mapping consulted when building the model and when setting
        # per-trajectory loss weights (e.g. "af_params_dir", "weights_*").
        self.advanced_settings = advanced_settings
        # Cached AF2 design model; None until _ensure_model() builds it, and
        # reset to None by invalidate().
        self.af_model = None
        # One-shot flags recording which optional loss terms have already been
        # registered on af_model, so reuse only updates their weights instead
        # of adding the same loss callback twice.
        self._rg_loss_added = False
        self._iptm_loss_added = False
        self._termini_loss_added = False
        self._helix_loss_added = False

def _ensure_model(self):
    """
    (Re)initialize the AF2 design model if needed.

    The model is rebuilt when ``af_model`` is ``None`` – either on first
    use or after an explicit ``invalidate()`` call. External calls to
    ``clear_mem()`` (e.g. during MPNN validation) free the underlying
    JAX device buffers, so callers **must** call ``invalidate()`` after
    any such ``clear_mem()`` invocation to ensure a fresh model is
    created on the next trajectory.
    """
    # Fast path: a cached model is assumed valid (see docstring contract).
    if self.af_model is not None:
        return

    # Release any stale device buffers before allocating a new model.
    clear_mem()
    self.af_model = mk_afdesign_model(
        protocol="binder",
        debug=False,
        data_dir=self.advanced_settings["af_params_dir"],
        use_multimer=self.advanced_settings["use_multimer_design"],
        num_recycles=self.advanced_settings["num_recycles_design"],
        best_metric="loss",
    )
    print("Initialised reusable AF2 trajectory model.")

def invalidate(self):
    """Mark the cached model as stale so ``_ensure_model`` rebuilds it.

    Call this after any external ``clear_mem()`` invocation that may
    have freed the JAX device buffers backing ``self.af_model``.
    """
    self.af_model = None
    # A freshly built model has no loss terms registered yet, so the
    # "already added" flags must be cleared alongside the model itself;
    # otherwise reuse would skip re-adding the loss callbacks.
    self._rg_loss_added = False
    self._iptm_loss_added = False
    self._termini_loss_added = False
    self._helix_loss_added = False

def prepare_trajectory(
self,
starting_pdb,
chain,
target_hotspot_residues,
length,
seed,
helicity_value,
):
self._ensure_model()

hotspot = None if target_hotspot_residues == "" else target_hotspot_residues

self.af_model.prep_inputs(
pdb_filename=starting_pdb,
chain=chain,
binder_len=length,
hotspot=hotspot,
seed=seed,
rm_aa=self.advanced_settings["omit_AAs"],
rm_target_seq=self.advanced_settings["rm_template_seq_design"],
rm_target_sc=self.advanced_settings["rm_template_sc_design"],
)
# Reset recycles for each run in case prior runs changed it (e.g. 4stage beta optimisation).
self.af_model.set_opt(num_recycles=self.advanced_settings["num_recycles_design"])

self.af_model.opt["weights"].update(
{
"pae": self.advanced_settings["weights_pae_intra"],
"plddt": self.advanced_settings["weights_plddt"],
"i_pae": self.advanced_settings["weights_pae_inter"],
"con": self.advanced_settings["weights_con_intra"],
"i_con": self.advanced_settings["weights_con_inter"],
}
)

# Redefine intramolecular contacts (con) and intermolecular contacts (i_con) definitions.
self.af_model.opt["con"].update(
{
"num": self.advanced_settings["intra_contact_number"],
"cutoff": self.advanced_settings["intra_contact_distance"],
"binary": False,
"seqsep": 9,
}
)
self.af_model.opt["i_con"].update(
{
"num": self.advanced_settings["inter_contact_number"],
"cutoff": self.advanced_settings["inter_contact_distance"],
"binary": False,
}
)

# clear GPU memory for new trajectory
clear_mem()
if self.advanced_settings["use_rg_loss"]:
if not self._rg_loss_added:
add_rg_loss(self.af_model, self.advanced_settings["weights_rg"])
self._rg_loss_added = True
else:
self.af_model.opt["weights"]["rg"] = self.advanced_settings["weights_rg"]
elif self._rg_loss_added:
self.af_model.opt["weights"]["rg"] = 0.0

if self.advanced_settings["use_i_ptm_loss"]:
if not self._iptm_loss_added:
add_i_ptm_loss(self.af_model, self.advanced_settings["weights_iptm"])
self._iptm_loss_added = True
else:
self.af_model.opt["weights"]["i_ptm"] = self.advanced_settings["weights_iptm"]
elif self._iptm_loss_added:
self.af_model.opt["weights"]["i_ptm"] = 0.0

if self.advanced_settings["use_termini_distance_loss"]:
if not self._termini_loss_added:
add_termini_distance_loss(
self.af_model, self.advanced_settings["weights_termini_loss"]
)
self._termini_loss_added = True
else:
self.af_model.opt["weights"]["NC"] = self.advanced_settings[
"weights_termini_loss"
]
elif self._termini_loss_added:
self.af_model.opt["weights"]["NC"] = 0.0

if not self._helix_loss_added:
add_helix_loss(self.af_model, helicity_value)
self._helix_loss_added = True
else:
self.af_model.opt["weights"]["helix"] = helicity_value

# initialise binder hallucination model
af_model = mk_afdesign_model(protocol="binder", debug=False, data_dir=advanced_settings["af_params_dir"],
use_multimer=advanced_settings["use_multimer_design"], num_recycles=advanced_settings["num_recycles_design"],
best_metric='loss')

# sanity check for hotspots
if target_hotspot_residues == "":
target_hotspot_residues = None

af_model.prep_inputs(pdb_filename=starting_pdb, chain=chain, binder_len=length, hotspot=target_hotspot_residues, seed=seed, rm_aa=advanced_settings["omit_AAs"],
rm_target_seq=advanced_settings["rm_template_seq_design"], rm_target_sc=advanced_settings["rm_template_sc_design"])

### Update weights based on specified settings
af_model.opt["weights"].update({"pae":advanced_settings["weights_pae_intra"],
"plddt":advanced_settings["weights_plddt"],
"i_pae":advanced_settings["weights_pae_inter"],
"con":advanced_settings["weights_con_intra"],
"i_con":advanced_settings["weights_con_inter"],
})

# redefine intramolecular contacts (con) and intermolecular contacts (i_con) definitions
af_model.opt["con"].update({"num":advanced_settings["intra_contact_number"],"cutoff":advanced_settings["intra_contact_distance"],"binary":False,"seqsep":9})
af_model.opt["i_con"].update({"num":advanced_settings["inter_contact_number"],"cutoff":advanced_settings["inter_contact_distance"],"binary":False})

return self.af_model

### additional loss functions
if advanced_settings["use_rg_loss"]:
# radius of gyration loss
add_rg_loss(af_model, advanced_settings["weights_rg"])

if advanced_settings["use_i_ptm_loss"]:
# interface pTM loss
add_i_ptm_loss(af_model, advanced_settings["weights_iptm"])
# hallucinate a binder
def binder_hallucination(
design_name,
starting_pdb,
chain,
target_hotspot_residues,
length,
seed,
helicity_value,
design_models,
advanced_settings,
design_paths,
failure_csv,
runtime=None,
):
global _DEFAULT_TRAJECTORY_RUNTIME

model_pdb_path = os.path.join(design_paths["Trajectory"], design_name+".pdb")

if advanced_settings["use_termini_distance_loss"]:
# termini distance loss
add_termini_distance_loss(af_model, advanced_settings["weights_termini_loss"])
if runtime is None:
if _DEFAULT_TRAJECTORY_RUNTIME is None:
_DEFAULT_TRAJECTORY_RUNTIME = TrajectoryDesignRuntime(advanced_settings)
runtime = _DEFAULT_TRAJECTORY_RUNTIME
Comment on lines +183 to +186
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The module-level _DEFAULT_TRAJECTORY_RUNTIME caches a single TrajectoryDesignRuntime instance across calls. If binder_hallucination() is invoked later in the same Python process with a different advanced_settings (e.g., running multiple campaigns in one session), the cached runtime will silently reuse the old settings/model, which can produce incorrect behavior and also keep GPU memory alive longer than intended. Prefer creating a new runtime per call when runtime is not provided, or key the cache by a stable identifier (e.g., settings path/hash) and provide an explicit reset/close API.

Copilot uses AI. Check for mistakes.

# add the helicity loss
add_helix_loss(af_model, helicity_value)
af_model = runtime.prepare_trajectory(
starting_pdb=starting_pdb,
chain=chain,
target_hotspot_residues=target_hotspot_residues,
length=length,
seed=seed,
helicity_value=helicity_value,
)

# calculate the number of mutations to do based on the length of the protein
greedy_tries = math.ceil(length * (advanced_settings["greedy_percentage"] / 100))
Expand Down Expand Up @@ -102,6 +231,8 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
# if best iteration has high enough confidence then continue
if initial_plddt > 0.65:
print("Initial trajectory pLDDT good, continuing: "+str(initial_plddt))
soft_iterations = advanced_settings["soft_iterations"]
temporary_iterations = advanced_settings["temporary_iterations"]
if advanced_settings["optimise_beta"]:
# temporarily dump model to assess secondary structure
af_model.save_pdb(model_pdb_path)
Expand All @@ -110,13 +241,13 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu

# if beta sheeted trajectory is detected then choose to optimise
if float(beta) > 15:
advanced_settings["soft_iterations"] = advanced_settings["soft_iterations"] + advanced_settings["optimise_beta_extra_soft"]
advanced_settings["temporary_iterations"] = advanced_settings["temporary_iterations"] + advanced_settings["optimise_beta_extra_temp"]
soft_iterations = soft_iterations + advanced_settings["optimise_beta_extra_soft"]
temporary_iterations = temporary_iterations + advanced_settings["optimise_beta_extra_temp"]
af_model.set_opt(num_recycles=advanced_settings["optimise_beta_recycles_design"])
print("Beta sheeted trajectory detected, optimising settings")

# how many logit iterations left
logits_iter = advanced_settings["soft_iterations"] - 50
logits_iter = soft_iterations - 50
if logits_iter > 0:
print("Stage 1: Additional Logits Optimisation")
af_model.clear_best()
Expand All @@ -129,10 +260,10 @@ def binder_hallucination(design_name, starting_pdb, chain, target_hotspot_residu
logit_plddt = initial_plddt

# perform softmax trajectory design
if advanced_settings["temporary_iterations"] > 0:
if temporary_iterations > 0:
print("Stage 2: Softmax Optimisation")
af_model.clear_best()
af_model.design_soft(advanced_settings["temporary_iterations"], e_temp=1e-2, models=design_models, num_models=1,
af_model.design_soft(temporary_iterations, e_temp=1e-2, models=design_models, num_models=1,
sample_models=advanced_settings["sample_models"], ramp_recycles=False, save_best=True)
softmax_plddt = get_best_plddt(af_model, length)
else:
Expand Down
Loading
Loading