PyRosetta for RFdiffusion

I will not lie: I often struggle to find a snippet of code that did something in PyRosetta, or I spend hours on a problem caused by something not working as I expect it to. I recently did a tricky project involving RFdiffusion and I kept slipping on the PyRosetta side. So, to make future me, others, and ChatGPT5 happy, here are some common operations that make working with PyRosetta for RFdiffusion easier.

Import

For easy copy-pasting, let's start PyRosetta and import things. As said in a previous post, I am not keen on star imports, because they are not meant to be used indiscriminately and can cause issues: a common Pythonic example is when someone mixes up collections.Counter and typing.Counter.

import pyrosetta
import pyrosetta.distributed  # needed for maybe_init below
import pyrosetta_help as ph
from types import ModuleType
from typing import List, Dict
# Better than star imports:
prc: ModuleType = pyrosetta.rosetta.core
prp: ModuleType = pyrosetta.rosetta.protocols
pru: ModuleType = pyrosetta.rosetta.utility
prn: ModuleType = pyrosetta.rosetta.numeric
prs: ModuleType = pyrosetta.rosetta.std
pr_conf: ModuleType = pyrosetta.rosetta.core.conformation
pr_scoring: ModuleType = pyrosetta.rosetta.core.scoring
pr_res: ModuleType = pyrosetta.rosetta.core.select.residue_selector
pr_options: ModuleType = pyrosetta.rosetta.basic.options

logger = ph.configure_logger()
pyrosetta.distributed.maybe_init(extra_options=ph.make_option_string(no_optH=False,
                                                                     ex1=None,
                                                                     ex2=None,
                                                                     ignore_unrecognized_res=True,
                                                                     load_PDB_components=False,
                                                                     ignore_waters=True,
                                                                     )
                                 )

Load from a PDB file:

pose: pyrosetta.Pose = pyrosetta.pose_from_file('👾👾👾.pdb')

or from a string:

pose = pyrosetta.Pose()
prc.import_pose.pose_from_pdbstring(pose, 'ATOM 👾👾👾...')

or from an mmCIF file. Note that this will not work with a file saved straight from PyMOL, because an extra entry is needed at the end, as discussed previously. I should say that working with mmCIF is uncommon, but it does have some nice metadata-handling advantages.

pose: pyrosetta.Pose = prc.import_pose.pose_from_file('👾👾👾.cif', read_fold_tree=True, type=prc.import_pose.FileType.CIF_file)

There are several other pose importers. Another useful one, especially for threading, is pyrosetta.pose_from_sequence; just watch out for the eye of Sauron when threading.

An extra complication arises when dealing with residue types that are not in the database. These need to be added to the pose before loading the file.

As always with PyRosetta, if you pass an illegal value, say prc.import_pose.pose_from_file(Exception), you'll get a TypeError listing the accepted arguments of the overloaded function, which sometimes differ from the help(fun) info.

A new residue type can be added from a params file or from memory. The official way to create a params file is with the Python 2.7 script shipped with Rosetta, but that was not good for my purposes, so I wrote a new converter, the rdkit_to_params module, which I also made into a web app due to popular demand.

A params file can be passed via the initialisation options (-extra_res_fa) or via the following:

from typing import List
pru: ModuleType = pyrosetta.rosetta.utility

params_filenames: List[str] = ...
pose: pyrosetta.Pose = ...
params_vector = pru.vector1_string()
params_vector.extend([f for f in params_filenames if f])  # skip empty entries
pyrosetta.generate_nonstandard_residue_set(pose, params_vector)

Having added a residue type does not mean a residue was added to the pose. For that you’d need to do:

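prc: ModuleType = pyrosetta.rosetta.core

name: str = ...  # the residue type name (NAME line of the params file)
resiset: prc.chemical.ResidueTypeSet = pyrosetta.generate_nonstandard_residue_set(pose, params_vector)
new_res = prc.conformation.ResidueFactory.create_residue(resiset.name_map(name))
# 2nd argument is the anchor residue of the new jump (here the last residue); the final True starts a new chain
pose.append_residue_by_jump(new_res, pose.total_residue(), '', '', True)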

In the rdkit_to_params module, Params.add_residuetype does the same, but from a params block held in a string:

def add_residuetype(self, pose: pyrosetta.Pose, params_block: str, name3: str, reset: bool = True) \
        -> pyrosetta.rosetta.core.chemical.ResidueTypeSet:
    """
    Adds the params to a copy of a residue type set of a pose.
    If reset is True it will also save it as the default RTS (it keeps other custom residue types).
    """
    rts = pose.conformation().modifiable_residue_type_set_for_conf(prc.chemical.FULL_ATOM_t)
    buffer = pyrosetta.rosetta.std.stringbuf(params_block)
    stream = pyrosetta.rosetta.std.istream(buffer)
    new = prc.chemical.read_topology_file(stream, name3, rts)
    rts.add_base_residue_type(new)
    if reset:
        pose.conformation().reset_residue_type_set_for_conf(rts)
    return rts

Connections, either polymeric (LOWER and UPPER) or otherwise (CONN1 etc.), are a topic that would require its own blog post, so I won't cover them here.

PDBInfo

PDB information (residue numbering, chain, occupancy and more) is stored in a pyrosetta.rosetta.core.pose.PDBInfo instance attached to the pose (accessible via the .pdb_info() method, which returns the attached instance, not a copy).

This is very handy. Two of its methods are pose2pdb(res=👾) and pdb2pose(chain=👾, res=👾), which allow one to jump between the two numbering schemes. A nice thing is that it does not segfault when a pose residue index that does not exist is requested; it raises a RuntimeError instead. If a PDB residue that does not exist is requested, it returns 0.

A parenthetical warning: one of its methods is .remarks(), which is a no-go. In the PDB format, REMARK is basically a comment with a really strict, tool-specific grammar, and it is a total mess. The .remarks() method does not work in PyRosetta and segfaults even when the loaded file did indeed have REMARK lines. Passing a pyrosetta.rosetta.core.io.Remarks instance will not work either, nor will appending to a Remarks, or even making one:

remarks = pyrosetta.rosetta.core.io.Remarks()
remark = pyrosetta.rosetta.core.io.RemarkInfo()
remark.value = 'Hello world'
remarks.append(remark) # ValueError: vector

So, were metadata to be added to a file, editing the PDB block might be a better bet, and editing the mmCIF dictionary data is even betterer. In a pinch, appending comment lines to a PDB helps too (but don't tell anyone).
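
A minimal sketch of the "edit the PDB block" route: prepend REMARK records to a PDB string before saving it. The helper name and the default remark number 999 are arbitrary choices of mine; the column layout ("REMARK" in columns 1-6, the remark number right-justified in 8-10, free text in 12-79) follows the PDB format description, and long text is wrapped rather than silently cut.

```python
from typing import Iterable


def add_remarks(pdb_block: str, texts: Iterable[str], remark_number: int = 999) -> str:
    """
    Prepend REMARK records to a PDB-format string (hypothetical helper).
    """
    width = 79 - 12 + 1  # columns 12-79 hold the text: 68 characters per record
    records = []
    for text in texts:
        # wrap long text over several records instead of truncating it
        chunks = [text[i:i + width] for i in range(0, len(text), width)] or ['']
        for chunk in chunks:
            # 'REMARK' + space + 3-wide number + space puts the text at column 12
            records.append(f'REMARK {remark_number:>3} {chunk}')
    return '\n'.join(records) + '\n' + pdb_block
```

The result is plain text, so it can be written straight to disk in place of the PDB block it was built from.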

Certain operations (RemodelMover, grafting etc.) will make the PDB information obsolete or lose it altogether.

The Pose class has a method split_by_chain which returns the chains (as defined in the fold tree) preserving the PDBInfo, but the function (external to pyrosetta.Pose) pyrosetta.rosetta.core.pose.append_pose_to_pose does not, so it needs correcting:

def add_chain(built: pyrosetta.Pose, new: pyrosetta.Pose, reset: bool = False) -> None:
    """
    Add a chain ``new`` to a pose ``built`` preserving the residue numbering.

    :param built: this is the pyrosetta.Pose that will be built into...
    :param new: the addendum
    :param reset: resets the PDBInfo for the chain present to A
    :return:
    """
    built_pi = built.pdb_info()
    if built_pi is None or reset:
        built_pi = prc.pose.PDBInfo(built)
        built.pdb_info(built_pi)
        for r in range(1, built.total_residue() + 1):
            built_pi.set_resinfo(res=r, chain_id='A', pdb_res=r)
    for chain in new.split_by_chain():
        offset: int = built.total_residue()
        pyrosetta.rosetta.core.pose.append_pose_to_pose(built, chain, new_chain=True)
        # the new `built` residues will not have PDBinfo
        chain_pi = chain.pdb_info()
        for r in range(1, chain.total_residue() + 1):
            built_pi.set_resinfo(res=r + offset, chain_id=chain_pi.chain(r), pdb_res=chain_pi.number(r))
    built_pi.obsolete(False)

The functions pyrosetta.rosetta.protocols.grafting.delete_region and pyrosetta.rosetta.protocols.grafting.return_region do not affect the PDBInfo.

Alignment and superposition

First, a silly, pedantic lexical side note. When one has a sequence of elements and spaces out the elements so they match, that is aligning. When one places one object over another but lets the other remain visible, say with a stagger like the traces in a joy plot, or by translating without rotating (e.g. ▽+△=✡︎), that is superimposing; when one fully rototranslates (e.g. ▽+△↺=▽), that is superposing. Here I will align sequences and superpose structures.

In PyRosetta one can superpose two poses with the following:

prs: ModuleType = pyrosetta.rosetta.std
mobile: pyrosetta.Pose = ...
ref: pyrosetta.Pose = ...
atom_map: prs.map_core_id_AtomID_core_id_AtomID = ...  # a mapping of mobile `pyrosetta.AtomID` to reference `pyrosetta.AtomID`
rmsd: float = pr_scoring.superimpose_pose(mod_pose=mobile, ref_pose=ref, atom_map=atom_map)  # moves `mobile` in place

There are actually several functions that allow superpositions, including presets for CA atoms, all atoms and so forth. The one above is customisable, as you have to tell it which atom goes to which.
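
To make the superpose/superimpose distinction concrete: given paired coordinates, a superposition is the rotation plus translation that minimises the RMSD. Below is a NumPy sketch of the textbook Kabsch algorithm, as an illustration of the idea rather than Rosetta's own implementation; the function name is hypothetical.

```python
from typing import Tuple

import numpy as np


def kabsch_superpose(mobile: np.ndarray, ref: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Rototranslate ``mobile`` (N×3) onto ``ref`` (N×3), rows paired by index.
    Returns the moved coordinates and the RMSD after superposition.
    """
    mobile_centroid = mobile.mean(axis=0)
    ref_centroid = ref.mean(axis=0)
    p = mobile - mobile_centroid
    q = ref - ref_centroid
    # SVD of the 3×3 covariance matrix
    u, _, vt = np.linalg.svd(p.T @ q)
    # flip the last axis if needed so the result is a rotation, not a reflection
    d = np.sign(np.linalg.det(vt.T @ u.T))
    rotation = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
    moved = p @ rotation.T + ref_centroid
    rmsd = float(np.sqrt(np.mean(np.sum((moved - ref) ** 2, axis=1))))
    return moved, rmsd
```

In PyRosetta terms, the two arrays would be the CA coordinates of the residue pairs that go into the atom map.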

So were we to superpose two poses by the CA atoms of a specified PDB chain that ought to be common, we could do:

prc: ModuleType = pyrosetta.rosetta.core
prp: ModuleType = pyrosetta.rosetta.protocols
pru: ModuleType = pyrosetta.rosetta.utility
prn: ModuleType = pyrosetta.rosetta.numeric
prs: ModuleType = pyrosetta.rosetta.std 
pr_conf: ModuleType = pyrosetta.rosetta.core.conformation
pr_scoring: ModuleType = pyrosetta.rosetta.core.scoring
pr_res: ModuleType = pyrosetta.rosetta.core.select.residue_selector
pr_options: ModuleType = pyrosetta.rosetta.basic.options

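def superpose_pose_by_chain(pose, ref, chain: str, strict: bool = True) -> float:
    """
    Superpose ``pose`` onto ``ref`` by the CA atoms of a PDB chain letter.

    :param pose: mobile pose, moved in place
    :param ref: reference pose
    :param chain: PDB chain letter present in both poses
    :param strict: check that paired residues have the same name3
    :return: RMSD after superposition
    """
    atom_map = prs.map_core_id_AtomID_core_id_AtomID()
    chain_sele: pr_res.ResidueSelector = pr_res.ChainSelector(chain)
    r: int  # reference pose residue number (Fortran)
    m: int  # mobile pose residue number (Fortran)
    # residues are paired in order, so zip silently stops at the shorter chain
    for r, m in zip(pr_res.selection_positions(chain_sele.apply(ref)),
                    pr_res.selection_positions(chain_sele.apply(pose))
                    ):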
        if strict:
            assert pose.residue(m).name3() == ref.residue(r).name3(), 'Mismatching residue positions!'
        ref_atom = pyrosetta.AtomID(ref.residue(r).atom_index("CA"), r)
        mobile_atom = pyrosetta.AtomID(pose.residue(m).atom_index("CA"), m)
        atom_map[mobile_atom] = ref_atom
    return pr_scoring.superimpose_pose(mod_pose=pose, ref_pose=ref, atom_map=atom_map)

If pyrosetta.AtomID seems familiar, that is because for constraints one has to play around with them a lot.

Sequence alignment comes into play when superposing proteins by their conserved regions. For that, Bio.Align.PairwiseAligner can be used (pairwise2 has been deprecated).

With the default settings, the character '-' in an input sequence is treated like a regular character, not as a free gap. Were one hellbent on having it treated as a gap, a custom substitution matrix would be needed.

To superpose homologues, the following would be fine:

from typing import Dict
from Bio.Align import PairwiseAligner, Alignment, substitution_matrices

def superpose_by_alignment(mobile: pyrosetta.Pose, ref: pyrosetta.Pose) -> float:
    """
    Superpose ``mobile`` onto ``ref`` by the CA atoms of the residues
    paired in a pairwise alignment of the sequences of the poses.

    :param mobile: pose to move (in place)
    :param ref: reference pose
    :return: RMSD after superposition
    """
    # ## align
    aligner = PairwiseAligner()
    aligner.substitution_matrix = substitution_matrices.load("BLOSUM62")
    ref_seq: str = ref.sequence()
    pose_seq: str = mobile.sequence()
    aln: Alignment = aligner.align(ref_seq, pose_seq)[0]
    # `aln.indices` has the mapping
    # an index of -1 is a map to a gap
    aln_map: Dict[int, int] = {int(t): int(q) for t, q in zip(aln.indices[0], aln.indices[1]) if
               q != -1 and t != -1}
    # ## make pyrosetta atom map
    # for the purpose of explanation the following block is not in its own function, but will be repeated in a minute.
    # were these all part of the same code, the repeated part would need to be its own function as it is very modular anyway!
    offset = 1  # alignment indices are 0-based, pose residue indices are 1-based
    atom_map = prs.map_core_id_AtomID_core_id_AtomID()
    for r, m in aln_map.items():
        ref_atom = pyrosetta.AtomID(ref.residue(r + offset).atom_index("CA"), r + offset)
        mobile_atom = pyrosetta.AtomID(mobile.residue(m + offset).atom_index("CA"), m + offset)
        atom_map[mobile_atom] = ref_atom
    # ## superpose and return RMSD
    return pr_scoring.superimpose_pose(mod_pose=mobile, ref_pose=ref, atom_map=atom_map)
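
The fiddly part of the function above is the bookkeeping: alignment indices are 0-based while pose residue indices are 1-based. As a dependency-free sketch of that step, here is a hypothetical helper that works from the two gapped strings of an alignment rather than from `aln.indices`. It assumes the n-th non-gap character corresponds to pose residue n, which should hold when the aligned strings came from pose.sequence(). The `identical_only` flag mirrors the common-span variant below.

```python
from typing import Dict


def gapped_to_pose_map(ref_aligned: str, mobile_aligned: str,
                       identical_only: bool = False) -> Dict[int, int]:
    """
    Map 1-based ref residue indices to 1-based mobile residue indices
    from two gapped alignment strings of equal length ('-' marks a gap).
    """
    if len(ref_aligned) != len(mobile_aligned):
        raise ValueError('Aligned sequences must have the same length')
    mapping: Dict[int, int] = {}
    r = m = 0  # residues consumed so far in each sequence
    for ref_char, mobile_char in zip(ref_aligned, mobile_aligned):
        if ref_char != '-':
            r += 1
        if mobile_char != '-':
            m += 1
        if ref_char == '-' or mobile_char == '-':
            continue  # a gap on either side: nothing to pair
        if identical_only and ref_char != mobile_char:
            continue  # drop mismatched pairs
        mapping[r] = m
    return mapping
```

The resulting dict can feed the AtomID loop directly, with no offset needed.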

However, if the poses to superpose differ by a span where any sequence identity is coincidental, such as a redesigned part of a domain, then we need to strongly penalise mismatches and barely penalise gap extensions.

from typing import Dict
from Bio.Align import PairwiseAligner, Alignment, substitution_matrices

def superpose_by_common_alignment(mobile: pyrosetta.Pose, ref: pyrosetta.Pose) -> float:
    """
    Superpose ``mobile`` onto ``ref`` by the CA atoms of aligned residues,
    this time only from spans that are common (identical residues).

    :param mobile: pose to move (in place)
    :param ref: reference pose
    :return: RMSD after superposition
    """
    # ## align
    aligner = PairwiseAligner()
    # We don't want to get mismatches aligned, so no to:
    # aligner.substitution_matrix = substitution_matrices.load("BLOSUM62")
    aligner.internal_gap_score = -10
    aligner.extend_gap_score = -0.01
    aligner.end_gap_score = -0.01
    ref_seq: str = ref.sequence()
    pose_seq: str = mobile.sequence()
    aln: Alignment = aligner.align(ref_seq, pose_seq)[0]
    # like before but with an extra condition `ref_seq[t] == pose_seq[q]`
    aln_map: Dict[int, int] = {t: q for t, q in zip(aln.indices[0], aln.indices[1]) if
               q != -1 and t != -1 and ref_seq[t] == pose_seq[q]}
    # ## make pyrosetta atom map
    ...
    # return RMSD
    return pr_scoring.superimpose_pose(mod_pose=mobile, ref_pose=ref, atom_map=atom_map)

In the case that one end of the sequence does not matter, there are attributes that control the left (N-terminal) and right (C-terminal) end gaps. For example, the attributes target_right_gap_score and target_right_extend_gap_score could be set to zero, which turns a C-terminal difference into a penalty-free run of gaps.

So what operations move the pose? A lot. The main two to keep an eye out for are RFdiffusion itself and FastRelax, with its pesky drift.

Thread

In PyRosetta one can mutate a residue via

prp: ModuleType = pyrosetta.rosetta.protocols
prp.simple_moves.MutateResidue(1, 'ALA').apply(pose)

This allows mutations not only to base residues but also to patched residues, such as NtermProteinFull (with the extra N-terminal proton), acetylated ones, etc. To see whether what you are after is in the database, it is honestly easier simply to look in the folder:

import itertools
from pathlib import Path
import pyrosetta

fa = 'database/chemical/residue_type_sets/fa_standard'

# residue types
print([p.name for p in (Path(pyrosetta.__file__).parent / fa / 'residue_types').glob('*/*.params')])

# patch types
patches_folder = Path(pyrosetta.__file__).parent / fa / 'patches'
print([p.name for p in itertools.chain(patches_folder.glob('*.txt'), patches_folder.glob('*/*.txt'))])


Just be aware that MutateResidue segfaults on error.

To remove the terminus patches one can do the following:

import pyrosetta
from types import ModuleType
prc: ModuleType = pyrosetta.rosetta.core
prp: ModuleType = pyrosetta.rosetta.protocols
pru: ModuleType = pyrosetta.rosetta.utility

def remove_terminus_patches(pose: pyrosetta.Pose) -> pyrosetta.Pose:
    """
    Return a copy of ``pose`` with only the protein residues and no terminus patches.
    """
    clean_template_pose: pyrosetta.Pose = pose.clone()
    prc.pose.remove_nonprotein_residues(clean_template_pose)
    ### find the termini (get_termini_from_pose lives in protocols.sic_dock)
    lowers = pru.vector1_std_pair_unsigned_long_protocols_sic_dock_Vec3_t()
    uppers = pru.vector1_std_pair_unsigned_long_protocols_sic_dock_Vec3_t()
    prp.sic_dock.get_termini_from_pose(clean_template_pose, lowers, uppers)
    ### remove the terminus types
    for upper, _ in uppers:
        prc.conformation.remove_upper_terminus_type_from_conformation_residue(clean_template_pose.conformation(), upper)
    for lower, _ in lowers:
        prc.conformation.remove_lower_terminus_type_from_conformation_residue(clean_template_pose.conformation(), lower)
    return clean_template_pose

However, mutating a whole pose based on a sequence is a pillar of homology modelling, namely threading. I have discussed threading and hybridisation in PyRosetta in another post, and most examples here are simply ported from a helper module of mine, so I'll be brief. Threading in Rosetta requires a “Grishin file”, which is a specific form of pairwise sequence alignment file. The mover is ThreadingMover, aided by StealSideChainsMover, namely:

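prc: ModuleType = pyrosetta.rosetta.core
prp: ModuleType = pyrosetta.rosetta.protocols

pose: pyrosetta.Pose = ...  # the template pose
target_sequence: str = ...
aln_file: str = ...  # Grishin-format alignment of the target and template sequences
clean_template_pose = remove_terminus_patches(pose)  # no termini from earlier!
target_pose = pyrosetta.Pose()
prc.pose.make_pose_from_sequence(target_pose, target_sequence, 'fa_standard')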

## Thread
align = prc.sequence.read_aln(format='grishin', filename=aln_file)[1]
threader = prp.comparative_modeling.ThreadingMover(align=align, template_pose=clean_template_pose)
threader.apply(target_pose)
## Steal sidechains
qt = threader.get_qt_mapping(target_pose)
steal = prp.comparative_modeling.StealSideChainsMover(clean_template_pose, qt)
steal.apply(target_pose)
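
For completeness, a minimal sketch of writing a Grishin file by hand, assuming the layout I recall from the Rosetta comparative modelling documentation: a `## target template` header, a `#` line, `scores_from_program: 0`, then each aligned sequence prefixed with its start offset (0 when the sequence begins at its first residue), and a closing `--`. The helper name is hypothetical; pyrosetta_help's write_grishin, used further down, is what I actually use.

```python
from pathlib import Path
from typing import Optional, Union


def grishin_block(target_name: str, target_aligned: str,
                  template_name: str, template_aligned: str,
                  outfile: Optional[Union[str, Path]] = None) -> str:
    """
    Format (and optionally write) a pairwise alignment as a Grishin file.
    Both sequences must already be aligned ('-' = gap) and of equal length.
    """
    if len(target_aligned) != len(template_aligned):
        raise ValueError('Aligned sequences must have the same length')
    text = (f'## {target_name} {template_name}\n'
            '#\n'
            'scores_from_program: 0\n'
            f'0 {target_aligned}\n'    # leading 0: start offset of the target
            f'0 {template_aligned}\n'  # leading 0: start offset of the template
            '--\n')
    if outfile is not None:
        Path(outfile).write_text(text)
    return text
```

The written file would then be read back with prc.sequence.read_aln(format='grishin', filename=...) as in the snippet above.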

The main point of the aforementioned post about threading was loop modelling by cannibalising AF2, which can be done with the following settings before applying the ThreadingMover:

from typing import List

# optional set-up for loop modelling from reference poses
poses: List[pyrosetta.Pose] = ...  # poses to cannibalise for loops
lengths: List[int] = [3,]  # 3, 6, 9 are traditional choices
fragment_sets = pru.vector1_std_shared_ptr_core_fragment_FragSet_t(len(lengths))
for i, l in enumerate(lengths):
    fragment_sets[i + 1] = prc.fragment.ConstantLengthFragSet(l)  # vector1 is 1-indexed
    for donor in poses:
        prc.fragment.steal_constant_length_frag_set_from_pose(donor, fragment_sets[i + 1])

threader = ...
threader.build_loops(True)
threader.randomize_loop_coords(True)  # default
threader.frag_libs(fragment_sets)
threader.apply(...)

As may be apparent, the ThreadingMover does not like non-base residues. So for ligand residues, one can just use the append_subpose_to_pose function:

prc: ModuleType = pyrosetta.rosetta.core
pr_res: ModuleType = pyrosetta.rosetta.core.select.residue_selector

def steal_ligands(donor_pose, acceptor_pose) -> None:
    """
    Steals non-Protein residues from donor_pose and adds them to acceptor_pose

    Do not use with nucleic acid polymers.

    :param donor_pose:
    :param acceptor_pose:
    :return:
    """
    PROTEIN = prc.chemical.ResidueProperty.PROTEIN
    prot_sele = pr_res.ResiduePropertySelector(PROTEIN)
    not_sele = pr_res.NotResidueSelector(prot_sele)
    rv = pr_res.ResidueVector(not_sele.apply(donor_pose))
    # if it were DNA...
    # for from_res, to_res in ph.rangify(rv):
    #     prc.pose.append_subpose_to_pose(acceptor_pose, donor_pose, from_res, to_res, True)
    for res in rv:
        prc.pose.append_subpose_to_pose(acceptor_pose, donor_pose, res, res, True)
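
The commented-out DNA branch relies on `ph.rangify` to collapse consecutive residue indices into (start, end) runs, so that each polymer segment is appended in one go. A hypothetical pure-Python equivalent (my own sketch, not the pyrosetta_help source) could be:

```python
from typing import Iterable, List, Tuple


def rangify_sketch(indices: Iterable[int]) -> List[Tuple[int, int]]:
    """
    Collapse integers into inclusive (start, end) runs of consecutive values.
    """
    runs: List[Tuple[int, int]] = []
    for i in sorted(set(indices)):
        if runs and i == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], i)  # extend the current run
        else:
            runs.append((i, i))  # start a new run
    return runs
```

Only use the ranged append for genuine polymer segments: two unrelated ligands that happen to have consecutive indices would form one run too.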

Now back to threading. As mentioned, I just use my helper function to deal with the above:

import pyrosetta_help as ph

def thread(template_block, target_seq, target_name, template_name,
           temp_folder='/tmp'):
    # load template
    template = pyrosetta.Pose()
    prc.import_pose.pose_from_pdbstring(template, template_block)
    # thread
    aln_filename = f'{temp_folder}/{template_name}-{target_name}.grishin'
    ph.write_grishin(target_name=target_name,
                     target_sequence=target_seq,
                     template_name=template_name,
                     template_sequence=template.sequence(),
                     outfile=aln_filename
                     )
    aln: prc.sequence.SequenceAlignment = prc.sequence.read_aln(format='grishin', filename=aln_filename)[1]
    threaded: pyrosetta.Pose
    threader: prp.comparative_modeling.ThreadingMover
    threadites: pru.vector1_bool
    threaded, threader, threadites = ph.thread(target_sequence=target_seq,
                                               template_pose=template,
                                               target_name=target_name,
                                               template_name=template_name,
                                               align=aln
                                               )
    # no need to superpose. It is already aligned
    # superpose(template, threaded)
    # fix pdb info
    n = threaded.total_residue()
    pi = prc.pose.PDBInfo(n)
    for i in range(1, n + 1):
        pi.number(i, i)
        pi.chain(i, 'A')
    threaded.pdb_info(pi)
    return threaded

A Pythonic parenthesis. Names of callable objects are, by good practice, verbs; however, things get confusing when a word can be both a noun and a verb: I have tripped up so often with thread and fragment. Hence the use of the past participle in the above. For fragment, which has initial-stress derivation, I naughtily use fràgment and fragmènt.

Relax

The workhorse of Rosetta is without question the FastRelax mover. It is very customisable, so most often there is no need to play with packers, Monte Carlo samplers etc. This is true for design mode too. I won't go into detail on the basics, as these are covered in every tutorial.

scorefxn: pr_scoring.ScoreFunction = pyrosetta.get_fa_scorefxn()
# don't forget to set weights!
scorefxn.set_weight(pr_scoring.ScoreType.atom_pair_constraint, 3)
scorefxn.set_weight(pr_scoring.ScoreType.coordinate_constraint, 5)
cycles = 3  # more than 5 is probably not going to do much
relax = prp.relax.FastRelax(scorefxn, cycles)
# set movemap to freeze residues
# say repack the sidechains of chain A and its neighbours
# allow backbone movements for chain A only
distance: float = 8.  # Å, an arbitrary example value for the neighbourhood
movemap = pyrosetta.MoveMap()
rs: ModuleType = prc.select.residue_selector
chainA_sele: rs.ResidueSelector = rs.ChainSelector('A')
chainA: pru.vector1_bool = chainA_sele.apply(pose)
# arguments: focus selector, distance, whether to include the focus itself
neigh_sele: rs.ResidueSelector = rs.NeighborhoodResidueSelector(chainA_sele, distance, True)
neighs: pru.vector1_bool = neigh_sele.apply(pose)
movemap.set_chi(neighs)
movemap.set_bb(chainA)
movemap.set_jump(False)
relax.set_movemap(movemap)
# go!
relax.apply(pose)

In the case of RFdiffusion, the template ought to have been minimised beforehand. Otherwise, either the scores will reflect which design happened to be minimised better, if the other chains are also minimised (don't), or the scores will reflect whether or not the designed binder interacts with a spuriously strained part of the binding partner.
In the above example, the sidechains of the neighbourhood were freed purely as an illustration. If you do repack these sidechains, which is reasonable, make sure to keep the complex, because the original binding partner will no longer match it.

Parametrically designed Christmas tree (the old-school way)
https://michelanglo.sgc.ox.ac.uk/r/christmas

A cool thing one can do with PyRosetta is use electron density as a constraint (tutorial). Theoretically, one could use RFdiffusion via the Blender API, gemmi and NumPy manipulations to make a custom-shaped protein, but that would be for fun and is not actually relevant here.

If there are covalent bond lengths that need fixing, then we have to do it in Cartesian space:

constraint_weight = 5
scorefxn: pr_scoring.ScoreFunction = pyrosetta.create_score_function('ref2015_cart')
scorefxn.set_weight(pr_scoring.ScoreType.coordinate_constraint, constraint_weight)
scorefxn.set_weight(pr_scoring.ScoreType.angle_constraint, constraint_weight)
scorefxn.set_weight(pr_scoring.ScoreType.atom_pair_constraint, constraint_weight)

relax = pyrosetta.rosetta.protocols.relax.FastRelax(scorefxn, 3)
relax.cartesian(True)
relax.minimize_bond_angles(True)
relax.minimize_bond_lengths(True)
relax.apply(pose)

In both examples I allude to constraints. These are well covered elsewhere and can be set up in various ways; for example, given the pose index and atom name of the atoms, we can do:

def constrain_distance(pose, fore_idx, fore_name, aft_idx, aft_name, x0_in=1.334, sd_in=0.2, tol_in=0.02, weight=1):
    AtomPairConstraint = pr_scoring.constraints.AtomPairConstraint  # noqa
    fore = pyrosetta.AtomID(atomno_in=pose.residue(fore_idx).atom_index(fore_name),
                                rsd_in=fore_idx)
    aft = pyrosetta.AtomID(atomno_in=pose.residue(aft_idx).atom_index(aft_name),
                              rsd_in=aft_idx)
    fun = pr_scoring.func.FlatHarmonicFunc(x0_in=x0_in, sd_in=sd_in, tol_in=tol_in)
    if weight != 1:
        fun = pr_scoring.func.ScalarWeightedFunc(weight, fun)
    con = AtomPairConstraint(fore, aft, fun)
    pose.add_constraint(con)
    return con

import math

def constrain_angle(pose, fore_idx, fore_name, mid_idx, mid_name, aft_idx, aft_name, x0_in=math.radians(109), sd_in=math.radians(10), weight=1):
    AngleConstraint = pr_scoring.constraints.AngleConstraint
    fore = pyrosetta.AtomID(atomno_in=pose.residue(fore_idx).atom_index(fore_name),
                                rsd_in=fore_idx)
    mid = pyrosetta.AtomID(atomno_in=pose.residue(mid_idx).atom_index(mid_name),
                                rsd_in=mid_idx)
    aft = pyrosetta.AtomID(atomno_in=pose.residue(aft_idx).atom_index(aft_name),
                              rsd_in=aft_idx)
    fun = pr_scoring.func.CircularHarmonicFunc(x0_radians=x0_in, sd_radians=sd_in)
    if weight != 1:
        fun = pr_scoring.func.ScalarWeightedFunc(weight, fun)
    con = AngleConstraint(fore, mid, aft, fun)
    pose.add_constraint(con)
    return con

def constrain_position(pose: pyrosetta.Pose, target_index: int, ref_index: int, x0_in=0., sd_in=0.01, weight=1):
    """
    Restrain the CA of ``target_index`` to its current coordinates (``ref_index`` supplies the fixed reference atom).
    """
    ref_ca = pyrosetta.AtomID(atomno_in=pose.residue(ref_index).atom_index('CA'), rsd_in=ref_index)
    target_ca = pyrosetta.AtomID(atomno_in=pose.residue(target_index).atom_index('CA'), rsd_in=target_index)
    target_xyz = pose.residue(target_index).xyz(target_ca.atomno())
    fun = pr_scoring.func.HarmonicFunc(x0_in=x0_in, sd_in=sd_in)
    if weight != 1:
        fun = pr_scoring.func.ScalarWeightedFunc(weight, fun)
    con = pr_scoring.constraints.CoordinateConstraint(a1=target_ca, fixed_atom_in=ref_ca, xyz_target_in=target_xyz, func=fun, scotype=pr_scoring.ScoreType.coordinate_constraint)
    pose.add_constraint(con)
    return con

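def constrain_peptide_gap(pose, chain_break, x0_in=1.334, sd_in=0.2, tol_in=0.02):
    """
    Close a gap between residues ``chain_break`` and ``chain_break + 1``
    by restraining their C-N distance to a peptide-bond length.
    """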
    AtomPairConstraint = pr_scoring.constraints.AtomPairConstraint  # noqa
    fore_c = pyrosetta.AtomID(atomno_in=pose.residue(chain_break).atom_index('C'),
                                rsd_in=chain_break)
    aft_n = pyrosetta.AtomID(atomno_in=pose.residue(chain_break + 1).atom_index('N'),
                              rsd_in=chain_break + 1)
    fun = pr_scoring.func.FlatHarmonicFunc(x0_in=x0_in, sd_in=sd_in, tol_in=tol_in)
    con = AtomPairConstraint(fore_c, aft_n, fun)
    pose.add_constraint(con)

One thing to keep an eye out for is how healthy the constraints are after minimisation:

def show_cons(pose):
    """
    print the score for the various constraints
    """
    pi = pose.pdb_info()
    get_atomname = lambda atomid: pose.residue(atomid.rsd()).atom_name(atomid.atomno()).strip()
    get_description = lambda atomid: f'{pose.residue(atomid.rsd()).name3()}{pi.pose2pdb(atomid.rsd()).strip().replace(" ", ":")}.{get_atomname(atomid)}'
    for con in pose.constraint_set().get_all_constraints():
        if con.type() == 'AtomPair':
            print(con.type(), get_description(con.atom1()), get_description(con.atom2()), con.score(pose))
        elif con.type() == 'Angle':
            print(con.type(), get_description(con.atom1()), get_description(con.atom2()), get_description(con.atom3()), con.score(pose))
        else:
            print(con.type(), con.score(pose))
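
To read the numbers that show_cons prints, it helps to know the shape of the functions. Here is a pure-Python mirror of the forms as I understand them from the Rosetta constraint documentation (an assumption worth double-checking against the source): HarmonicFunc is ((x - x0)/sd)², FlatHarmonicFunc is zero within x0 ± tol and harmonic on the excess beyond it, and CircularHarmonicFunc is harmonic on the wrapped angular difference. A ScalarWeightedFunc then scales the value by its weight.

```python
import math


def harmonic(x: float, x0: float, sd: float) -> float:
    """Harmonic penalty: ((x - x0) / sd) squared."""
    return ((x - x0) / sd) ** 2


def flat_harmonic(x: float, x0: float, sd: float, tol: float) -> float:
    """Zero within x0 ± tol, harmonic on the excess beyond that."""
    excess = max(abs(x - x0) - tol, 0.0)
    return (excess / sd) ** 2


def circular_harmonic(x_rad: float, x0_rad: float, sd_rad: float) -> float:
    """Harmonic on the angular difference wrapped into [-pi, pi)."""
    diff = (x_rad - x0_rad + math.pi) % (2 * math.pi) - math.pi
    return (diff / sd_rad) ** 2
```

With the constrain_peptide_gap defaults (x0=1.334, sd=0.2, tol=0.02), a C-N distance of 1.554 Å scores about 1, i.e. one sd outside the flat zone, whereas an unclosed gap at 3.334 Å scores about 98.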

Design

One can do design with FastRelax. Due to the way the mover samples, it fixes sub-optimal residues one by one, as opposed to large-scale remodelling; for that, Remodel might be better suited (tutorial). Design is really useful for RFdiffusion, as the latter can in essence be thought of as coarse-grained, so there may be some residue combinations that are subpar. Obviously, when FastDesign chooses a residue, the choice is based on the scorefunction, which is imperfect: a quirky example of this is its proclivity to change certain residues in Strep-tag when bound to streptavidin, even though the former is the product of rounds and rounds of phage display.

To run design, one needs to create a task factory onto which restrictions (operations) are pushed. Confusingly, the default is to design everything. ProhibitSpecifiedBaseResidueTypes is really handy for preventing cysteine from making life hard and also for correcting surface charges: proteins are not very soluble if their isoelectric point is close to pH 7.4, and they aggregate if there are exposed hydrophobic patches. On the latter note, ProteinMPNN with the vanilla weights might inadvertently make a transmembrane helix, so that is something to keep an eye out for.
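
As a quick screen for the transmembrane-helix worry, one could run the classic Kyte-Doolittle hydropathy scan on the designed sequence: a 19-residue window averaging above roughly 1.6 is the traditional flag for a possible membrane-spanning segment. This is a crude heuristic of my choosing, not part of the design protocol, and a dedicated predictor would be more reliable.

```python
from typing import List, Tuple

# Kyte & Doolittle (1982) hydropathy values
KYTE_DOOLITTLE = {
    'A': 1.8, 'R': -4.5, 'N': -3.5, 'D': -3.5, 'C': 2.5,
    'Q': -3.5, 'E': -3.5, 'G': -0.4, 'H': -3.2, 'I': 4.5,
    'L': 3.8, 'K': -3.9, 'M': 1.9, 'F': 2.8, 'P': -1.6,
    'S': -0.8, 'T': -0.7, 'W': -0.9, 'Y': -1.3, 'V': 4.2,
}


def hydrophobic_windows(sequence: str, window: int = 19,
                        threshold: float = 1.6) -> List[Tuple[int, float]]:
    """
    Return (1-based window start, mean hydropathy) for every window above ``threshold``.
    Unknown letters (e.g. a ligand as 'X') count as 0.
    """
    values = [KYTE_DOOLITTLE.get(aa, 0.0) for aa in sequence.upper()]
    hits: List[Tuple[int, float]] = []
    for start in range(len(values) - window + 1):
        mean = sum(values[start:start + window]) / window
        if mean > threshold:
            hits.append((start + 1, round(mean, 2)))
    return hits
```

For example, hydrophobic_windows(pose.chain_sequence(1)) returning anything at all would be a reason to look at the design more closely.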

def create_design_tf(pose:pyrosetta.Pose, design_sele: pr_res.ResidueSelector, distance:int) -> prc.pack.task.TaskFactory:
    """
    Create design task factory for relax.
    Designs the ``design_sele`` and repacks around ``distance`` of it.

    Remember to do
    
    .. code-block:: python

        relax.set_enable_design(True)
        relax.set_task_factory(task_factory)
    """
    # design is default, so this is not done:
    # residues_to_design = design_sele.apply(pose)
    # design_ops = prc.pack.task.operation.OperateOnResidueSubset(????, residues_to_design)
    no_cys = pru.vector1_std_string(1)
    no_cys[1] = 'CYS'
    no_cys_ops =  prc.pack.task.operation.ProhibitSpecifiedBaseResidueTypes(no_cys)
    # No design, but repack
    repack_sele = pr_res.NeighborhoodResidueSelector(design_sele, distance, False)
    residues_to_repack = repack_sele.apply(pose)
    repack_rtl = prc.pack.task.operation.RestrictToRepackingRLT()
    repack_ops = prc.pack.task.operation.OperateOnResidueSubset(repack_rtl, residues_to_repack)
    # No repack, no design
    frozen_sele = pr_res.NotResidueSelector(pr_res.OrResidueSelector(design_sele, repack_sele))
    residues_to_freeze = frozen_sele.apply(pose)
    prevent_rtl = prc.pack.task.operation.PreventRepackingRLT()
    frozen_ops = prc.pack.task.operation.OperateOnResidueSubset(prevent_rtl, residues_to_freeze)
    # pyrosetta.rosetta.core.pack.task.operation.RestrictAbsentCanonicalAASRLT
    # pyrosetta.rosetta.core.pack.task.operation.PreserveCBetaRLT
    task_factory = prc.pack.task.TaskFactory()
    task_factory.push_back(no_cys_ops)
    task_factory.push_back(repack_ops)
    task_factory.push_back(frozen_ops)
    return task_factory

When I used RFdiffusion I gave the algorithm a range in the contigmap, thus allowing the best length of an insertion to win. This does however make the pipeline a bit complicated. First I had to adapt the helper scripts of ProteinMPNN to be used as Python functions (see footnote). There are then two options for the design step: refer to the original skeleton sequence, where stretches of glycines mark the designed residues, or do a sequence alignment. The former requires being organised, while the latter could use the alignment code discussed above.
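
For the first option, the designed positions can be recovered from the skeleton sequence itself. A minimal sketch, assuming the RFdiffusion output has been loaded into a pose whose sequence is glycine at the diffused positions; the helper name is hypothetical and the minimum run length is an arbitrary guard of mine so that a lone native glycine in the fixed part is not mistaken for a designed stretch (a native GGG motif would still fool it).

```python
from typing import List


def designed_positions(skeleton_sequence: str, min_run: int = 3) -> List[int]:
    """
    1-based indices of residues in glycine runs of at least ``min_run``.
    """
    positions: List[int] = []
    run: List[int] = []
    for i, aa in enumerate(skeleton_sequence + '*', start=1):  # '*' flushes the last run
        if aa == 'G':
            run.append(i)
            continue
        if len(run) >= min_run:
            positions.extend(run)
        run = []
    return positions
```

These indices could then populate a ResidueIndexSelector for create_design_tf, much as design_different does with its alignment-derived indices.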

def design_different(pose: pyrosetta.Pose, ref: pyrosetta.Pose, cycles = 5, scorefxn=None):
    ref = ref.split_by_chain(1)
    # align_for_atom_map (defined elsewhere) is assumed to return {ref_index: pose_index}, both 1-based
    ref2pose: dict = align_for_atom_map(pose.split_by_chain(1), ref)
    conserved = set(ref2pose.values())
    idx_sele = pr_res.ResidueIndexSelector()
    for i in range(1, len(pose.chain_sequence(1)) + 1):  # + 1 so the last residue is included
        if i not in conserved:
            idx_sele.append_index(i)
    print(idx_sele.apply(pose))
    task_factory: prc.pack.task.TaskFactory = create_design_tf(pose, design_sele=idx_sele, distance=0)
    if scorefxn is None:
        scorefxn = pyrosetta.get_fa_scorefxn()
    relax = pyrosetta.rosetta.protocols.relax.FastRelax(scorefxn, cycles)
    relax.set_enable_design(True)
    relax.set_task_factory(task_factory)
    relax.apply(pose)
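For the other option, reading the designed positions off the skeleton sequence, here is a minimal sketch. It assumes, as in the footnote, that the designed stretches are runs of at least three glycines in the chain A sequence, and that chain A is the first chain of the pose, so that sequence position + 1 equals the pose index. The function name glycine_stretch_positions is hypothetical, and a genuine polyglycine stretch in the retained part of the protein would be picked up too.

```python
import re
from typing import List


def glycine_stretch_positions(sequence: str, min_run: int = 3, offset: int = 1) -> List[int]:
    """Positions covered by runs of at least min_run glycines.

    offset=1 returns 1-based positions (Rosetta pose / ProteinMPNN style);
    use offset=0 for Python indices.
    """
    positions: List[int] = []
    # rf'G{{{min_run},}}' expands to e.g. 'G{3,}'
    for match in re.finditer(rf'G{{{min_run},}}', sequence):
        positions.extend(range(match.start() + offset, match.end() + offset))
    return positions


# Example with a made-up skeleton sequence: two designed stretches
print(glycine_stretch_positions('MKGGGLAGGGGE'))  # [3, 4, 5, 8, 9, 10, 11]
```

The list could then be turned into a selector, as ResidueIndexSelector accepts a comma-separated index string: pr_res.ResidueIndexSelector(','.join(map(str, positions))), and passed to create_design_tf in place of the alignment-derived one.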

The degree of change during a FastDesign run depends on how strained the residues are, so it is important that the design is properly relaxed with FastRelax first, rather than relaxed and designed all at once.
Actually, I often run FastRelax in a first pass with the backbone fixed to prevent the model from blowing up.
This degree of change, _I think_, makes a good metric for scoring: if pose.sequence() changed by more than 20%, then that is a hard no.
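The 20% rule fits in a couple of small functions. A minimal sketch, assuming the two strings are pose.sequence() taken before and after the design run, which have the same length because design does not insert or delete residues. The function names are hypothetical and the 0.2 default is simply my threshold above, not a standard.

```python
def fraction_changed(before: str, after: str) -> float:
    """Fraction of positions whose residue differs between two equal-length sequences."""
    if len(before) != len(after):
        raise ValueError('sequences differ in length: align them first')
    if not before:
        return 0.0
    return sum(a != b for a, b in zip(before, after)) / len(before)


def too_much_change(before: str, after: str, threshold: float = 0.2) -> bool:
    """True if more than `threshold` of the residues were mutated."""
    return fraction_changed(before, after) > threshold


print(fraction_changed('MKLVE', 'MKAVD'))  # 0.4
print(too_much_change('MKLVE', 'MKAVD'))   # True
```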

Scoring

RFdiffusion is made to generate lots of designs, because not all of them are amazing. So they need to be scored.

Coarse-grain clashes

First, we need to determine whether the polyglycine skeleton clashes with the extended neighbourhood, when the protein is part of a larger machinery that could not feasibly be fed into RFdiffusion. One example is a membrane: say one were designing a binder to a transmembrane protein, the design should not cross the membrane, so one could get the transmembrane protein from the OPM database, superpose the skeleton and use the layers of O / N pseudo-atoms that OPM places at the membrane boundaries. For a project of mine, I had a multidimensional polymeric lattice. Scoring with a scorefunction that uses only the Lennard-Jones repulsion is a tad overkill and prone to errors, so simple distance-based calculations work:

import numpy as np

skeleton: pyrosetta.Pose = ...
clasher: pyrosetta.Pose = ...  # once combined with the skeleton, this pose should not clash with it

xyz_skeleton = extract_coords(skeleton)  # (N, 3) coordinates
xyz_clasher = extract_coords(clasher)    # (M, 3) coordinates

# (M, N) matrix of pairwise distances via broadcasting
all_distances = np.sqrt(np.sum((xyz_clasher[:, np.newaxis, :] - xyz_skeleton[np.newaxis, :, :]) ** 2, axis=-1))
n_clashing = np.count_nonzero(all_distances < 1.)  # atom pairs closer than 1 Å

If designing interactions, the same calculation with a cutoff of < 3. Å gives the number of atom pairs within hydrogen-bonding distance between the skeleton backbone and any atom of the reference.
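For a large clasher (a lattice, a membrane slab) the full distance matrix can eat a lot of memory. Here is a minimal sketch of the same count done in chunks, comparing squared distances to skip the square root. It assumes (N, 3) coordinate arrays such as those from extract_coords; the function name and chunk size are illustrative. With cutoff=1. it reproduces n_clashing, and with cutoff=3. it gives the contact count.

```python
import numpy as np
import numpy.typing as npt


def count_close_pairs(xyz_a: npt.NDArray, xyz_b: npt.NDArray,
                      cutoff: float, chunk: int = 2048) -> int:
    """Count atom pairs (one atom from each array) closer than `cutoff` Å.

    Rows of xyz_a are processed `chunk` at a time, so at most
    chunk x len(xyz_b) squared distances are held in memory.
    """
    xyz_a = np.asarray(xyz_a, dtype=float).reshape(-1, 3)
    xyz_b = np.asarray(xyz_b, dtype=float).reshape(-1, 3)
    cutoff_sq = cutoff ** 2
    n = 0
    for start in range(0, len(xyz_a), chunk):
        block = xyz_a[start:start + chunk]
        d2 = np.sum((block[:, np.newaxis, :] - xyz_b[np.newaxis, :, :]) ** 2, axis=-1)
        n += int(np.count_nonzero(d2 < cutoff_sq))
    return n
```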
Once a skeleton passes, we generate the sequence with ProteinMPNN, thread it, relax it and tune the sequence. Now we can do some proper scoring.

∆∆G

The first metric is the predicted Gibbs free energy from a scorefunction. The catch is that this is an absolute energy, which should not be taken too seriously. Scorefunctions are optimised for ∆∆G calculations, so they have empirical terms (i.e. fudge factors) to make them work, including the REF value (a per-residue-type weight; the ref2015 values are shown below). This means that a linear, unfolded string of peptide will not have a predicted absolute energy of folding of zero. However, in terms of magnitude, if there is a screaming high value then the model is for the bin.

{'ALA': 1.32468, 'ARG': -0.09474, 'ASN': -1.34026, 'ASP': -2.14574, 'CYS': 3.25479, 'GLN': -1.45095, 'GLU': -2.72453, 'GLY': 0.79816, 'HIS': -0.30065, 'ILE': 2.30374, 'LEU': 1.66147, 'LYS': -0.71458, 'MET': 1.65735, 'PHE': 1.21829, 'PRO': -1.64321, 'SER': -0.28969, 'THR': 1.15175, 'TRP': 2.26099, 'TYR': 0.58223, 'VAL': 2.64269}
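To get a feel for how much this term alone shifts the total, here is a minimal sketch that sums the values above over a sequence. It assumes a weight of 1.0 for the ref term (which I believe is its weight in ref2015; pass another value otherwise) and ignores every other term, so it is an offset, not a score. The function name ref_contribution is hypothetical.

```python
from typing import Dict

# ref2015 REF values, copied from the dictionary above
REF2015: Dict[str, float] = {'ALA': 1.32468, 'ARG': -0.09474, 'ASN': -1.34026, 'ASP': -2.14574,
                             'CYS': 3.25479, 'GLN': -1.45095, 'GLU': -2.72453, 'GLY': 0.79816,
                             'HIS': -0.30065, 'ILE': 2.30374, 'LEU': 1.66147, 'LYS': -0.71458,
                             'MET': 1.65735, 'PHE': 1.21829, 'PRO': -1.64321, 'SER': -0.28969,
                             'THR': 1.15175, 'TRP': 2.26099, 'TYR': 0.58223, 'VAL': 2.64269}
ONE_TO_THREE: Dict[str, str] = {'A': 'ALA', 'R': 'ARG', 'N': 'ASN', 'D': 'ASP', 'C': 'CYS',
                                'Q': 'GLN', 'E': 'GLU', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE',
                                'L': 'LEU', 'K': 'LYS', 'M': 'MET', 'F': 'PHE', 'P': 'PRO',
                                'S': 'SER', 'T': 'THR', 'W': 'TRP', 'Y': 'TYR', 'V': 'VAL'}


def ref_contribution(sequence: str, ref_weight: float = 1.0) -> float:
    """Sum of the REF term over a one-letter sequence (all other terms ignored)."""
    return ref_weight * sum(REF2015[ONE_TO_THREE[aa]] for aa in sequence)
```

For instance, each glutamate pulls the total down by about 2.7 while each valine pushes it up by about 2.6, before any interaction is even counted.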

Per residue score

The scorefunction is a sum of the scores of each particle/residue, and the per-residue scores are obtainable. These are a common way to judge how accurate the model likely is. The per-residue score should not exceed +10 kcal/mol: most residues in a minimised PDB protein are under +5 kcal/mol, with one or two at the +8 kcal/mol mark, likely due to removed ligands and crystallographic waters (implicit solvent is an approximation of bulk solvent, not of frozen waters). Two things are worth remembering: the ScoreFunction does not store any data, whereas Pose.energies() does, yet both can access it. Personally I prefer the latter, but the former has the weights, which is handy.

from typing import Dict
from types import ModuleType
import pyrosetta

prc: ModuleType = pyrosetta.rosetta.core
pru: ModuleType = pyrosetta.rosetta.utility
pr_scoring: ModuleType = pyrosetta.rosetta.core.scoring

pose: pyrosetta.Pose = ...
scorefxn: pyrosetta.ScoreFunction = pyrosetta.get_fa_scorefxn()
# showcase weights:
weights: Dict[str, float] = {w.name: scorefxn.get_weight(w) for w in scorefxn.get_nonzero_weighted_scoretypes()}
print(weights)
# get per residue scores:
scorefxn(pose)  # score first so that pose.energies() is up to date
res_scores = []
for i in range(1, pose.total_residue() + 1):
    v = pru.vector1_bool(pose.total_residue())  # residue mask, all False
    v[i] = True  # select only residue i
    res_scores.append(scorefxn.get_sub_score(pose, v))

# some magic can be done with the annoying EMapVector
v: pr_scoring.EMapVector = scorefxn.weights()
print(v[pr_scoring.total_score], v[pr_scoring.fa_sol])

Whereas using pose.energies() is a lot easier:

from typing import Dict
import numpy.typing as npt
import pandas as pd

scorefxn(pose)  # fills pose.energies
energies: pr_scoring.Energies = pose.energies()
typed_energies: Dict[str, float] = energies.active_total_energies()
a: npt.NDArray = energies.residue_total_energies_array()
energy = pd.DataFrame(a)  # amazingly civilised: one row per residue, one column per score term
# alternatively, a specific residue can be queried as an EMapVector, as above
v: pr_scoring.EMapVector = energies.residue_total_energies(1)
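The rules of thumb above (+5 and +10 kcal/mol) can be applied to that DataFrame directly. A minimal sketch, assuming a total_score column (change column if yours is named differently) and one row per residue in pose order, so the pose index is the row position + 1. The function name is hypothetical.

```python
import pandas as pd


def flag_high_energy_residues(energy: pd.DataFrame, warn: float = 5.0, fail: float = 10.0,
                              column: str = 'total_score') -> pd.DataFrame:
    """Label each residue 'ok', 'warn' (> warn) or 'fail' (> fail) from its total score."""
    totals = energy[column].astype(float)
    verdict = pd.Series('ok', index=energy.index)
    verdict.loc[totals > warn] = 'warn'
    verdict.loc[totals > fail] = 'fail'
    return pd.DataFrame({'pose_idx': range(1, len(energy) + 1),  # Rosetta numbering is 1-based
                         column: totals.values,
                         'verdict': verdict.values})
```

Usage would be along the lines of flagged = flag_high_energy_residues(energy) followed by flagged[flagged.verdict == 'fail'] to list the offenders.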

Do note that a two-body score (e.g. a hydrogen bond) involves two particles, so there is an option to toggle whether backbone hydrogen bonds are split between the residue pair or not (cf. this post).

scorefxn.weights()  # Create the EnergyMap
emopts = pr_scoring.methods.EnergyMethodOptions(scorefxn.energy_method_options())
emopts.hbond_options().decompose_bb_hb_into_pair_energies(True)
scorefxn.set_energy_method_options(emopts)

Parenthetically, if pd.DataFrame(pose.energies().residue_total_energies_array()) was not saved but is needed again, here is a function to read the energy table (the junk at the end) of a PDB dumped from a scored pose.

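import pandas as pd

def rosetta_pdb_to_df(pdbblock: str) -> pd.DataFrame:
    """Parse the per-residue energy table Rosetta appends to a PDB dumped from a scored pose."""
    parsable = False
    _data = []
    for line in pdbblock.split('\n'):
        if '#BEGIN_POSE_ENERGIES_TABLE' in line:
            parsable = True
            continue
        elif '#END_POSE_ENERGIES_TABLE' in line:
            break
        elif not parsable or not line.strip():
            continue
        parts = line.strip().split()
        if parts[0] == 'label':
            _data.append(parts)
        elif parts[0] == 'weights':
            # the last (total) column has no weight, it is written as NA
            _data.append([parts[0]] + list(map(float, parts[1:-1])) + [float('nan')])
        else:
            _data.append([parts[0]] + list(map(float, parts[1:])))
    data = pd.DataFrame(_data)
    data.columns = data.iloc[0]
    data = data.iloc[1:].copy()
    # the header row forced an object dtype: cast the energy columns back to floats
    data[data.columns[1:]] = data[data.columns[1:]].astype(float)
    return data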

Interface energy

The thing we want to have is a nice interface. There are two ways to score it. One is civilised, but prone to segfaults: the InterfaceAnalyzerMover.

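from typing import Sequence, Union

def score_interface(complex: Union[pyrosetta.Pose, Sequence[pyrosetta.Pose]], interface: str) -> dict:
    # a list of poses is first merged into a single multichain pose with add_chain
    if isinstance(complex, Sequence):
        _complex = complex[0].clone()
        for c in complex[1:]:
            add_chain(_complex, c)
        complex = _complex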
    ia = pyrosetta.rosetta.protocols.analysis.InterfaceAnalyzerMover(interface)
    ia.apply(complex)
    return {'complex_energy': ia.get_complex_energy(),
            'separated_interface_energy': ia.get_separated_interface_energy(),
            'complexed_sasa': ia.get_complexed_sasa(),
            'crossterm_interface_energy': ia.get_crossterm_interface_energy(),
            'interface_dG': ia.get_interface_dG(),
            'interface_delta_sasa': ia.get_interface_delta_sasa()}

Like many other movers, it does not allow ligands.

The interface is given as chain letters separated by an underscore, say A_BC (chain A against chains B and C). I believe this crashes out above three chains.

The other way is to score the chains in isolation (pose.split_by_chain()) or to translate one partner a silly distance away (having reset pose.energies()), preferably after repacking the surface layer.
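The split-chain route fits in a few lines. This sketch relies only on Pose.split_by_chain() and on calling a ScoreFunction on a pose, so it is written against those two behaviours rather than importing PyRosetta. The function name is hypothetical and, as written, it skips the surface repacking recommended above, so treat it as the quick-and-dirty version. A negative value means the complex scores better than the separated chains.

```python
from typing import Any, Callable


def split_interface_dG(pose: Any, scorefxn: Callable[[Any], float]) -> float:
    """Crude interface energy: E(complex) minus the sum of E(each chain scored alone).

    pose: anything with split_by_chain(), e.g. a pyrosetta.Pose
    scorefxn: anything callable on a pose, e.g. a pyrosetta.ScoreFunction
    No repacking is done: the separated chains keep their bound-state rotamers.
    """
    complex_energy = scorefxn(pose)
    separated_energy = sum(scorefxn(chain) for chain in pose.split_by_chain())
    return complex_energy - separated_energy
```

With PyRosetta this would be called as split_interface_dG(pose, pyrosetta.get_fa_scorefxn()).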

Changed sequence

As said above, how many residues changed when tweaked can be a proxy for something being not quite right.

Number of neighbours

oligomer: pyrosetta.Pose = ...
distance: float = ...  # Å
pose2pdb = oligomer.pdb_info().pose2pdb
chainA_sele = pr_res.ChainSelector('A')
# boolean vector (chain A itself excluded, as include_focus_in_subset is False):
v = pr_res.NeighborhoodResidueSelector(chainA_sele, distance, False).apply(oligomer)
# vector of pose idxs:
close_residues = prc.select.get_residues_from_subset(v)
# PDB numbering (residue number and chain) as strings
print([pose2pdb(r) for r in close_residues])

In another blog post I go through the two different residue neighbourhood selectors, but briefly, NeighborhoodResidueSelector is centroid-based (roughly the beta carbon), while CloseContactResidueSelector uses the closest atom.

cc_sele = pr_res.CloseContactResidueSelector()
cc_sele.central_residue_group_selector(chainA_sele)
cc_sele.threshold(3)  # Å
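A per-residue neighbour count (a crude burial measure) can also be computed without selectors. A minimal numpy sketch, assuming one representative atom per residue (e.g. the beta carbon, CA for glycine) already extracted into an (N, 3) array by whatever means; the 10 Å cutoff is an illustrative default and the function name is hypothetical.

```python
import numpy as np
import numpy.typing as npt


def neighbour_counts(xyz: npt.NDArray, cutoff: float = 10.0) -> npt.NDArray:
    """Number of other residues whose representative atom lies within `cutoff` Å.

    xyz: (N, 3) array, one representative atom per residue.
    """
    xyz = np.asarray(xyz, dtype=float).reshape(-1, 3)
    d2 = np.sum((xyz[:, np.newaxis, :] - xyz[np.newaxis, :, :]) ** 2, axis=-1)
    # subtract 1 so that a residue does not count itself
    return np.count_nonzero(d2 < cutoff ** 2, axis=1) - 1
```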

Hydrophobicity & co.

In the Therapeutic Antibody Profiler, a tool developed in OPIG, a set of metrics were characterised, such as surface hydrophobic patches, which would most likely be well worth using in protein binder design. (I've not used these yet, so I don't have a snippet to share.)

Movement

One may be after rigidity. An MD run is best for that, but PyRosetta can approximate it via the BackrubMover (discussed in detail in another blog post).

from typing import List

def movement(original: pyrosetta.Pose,
             trials: int = 100, temperature: float = 1.0, replicate_number: int = 20) -> List[float]:
    scorefxn = pyrosetta.get_fa_scorefxn()
    backrub = pyrosetta.rosetta.protocols.backrub.BackrubMover()
    monégasque = pyrosetta.rosetta.protocols.monte_carlo.GenericMonteCarloMover(maxtrials=trials,
                                                                                max_accepted_trials=trials,
                                                                                task_scaling=5,
                                                                                mover=backrub,
                                                                                temperature=temperature,
                                                                                sample_type='low',
                                                                                drift=True)
    monégasque.set_scorefxn(scorefxn)
    # run independent replicates from the original pose and record each backbone RMSD
    rs = []
    for i in range(replicate_number):
        variant = original.clone()
        monégasque.apply(variant)
        if monégasque.accept_counter() > 0:
            rs.append(pr_scoring.bb_rmsd(variant, original))
        else:
            rs.append(float('nan'))  # no move accepted: no meaningful RMSD
    return rs
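Since replicates where nothing was accepted come back as NaN, the list needs a NaN-aware summary before designs can be compared. A minimal sketch; the keys and the function name are hypothetical, and what counts as rigid enough is left to you.

```python
from typing import Dict, Sequence

import numpy as np


def summarise_movement(rmsds: Sequence[float]) -> Dict[str, float]:
    """NaN-aware summary of the per-replicate backbone RMSDs returned by movement()."""
    arr = np.asarray(rmsds, dtype=float)
    moved = arr[~np.isnan(arr)]  # replicates where at least one move was accepted
    if moved.size == 0:
        return {'fraction_moved': 0.0, 'mean_rmsd': float('nan'), 'max_rmsd': float('nan')}
    return {'fraction_moved': moved.size / arr.size,
            'mean_rmsd': float(moved.mean()),
            'max_rmsd': float(moved.max())}
```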

Conclusion

Hopefully this exhaustive, and exhausting, overview of the movers and functions in PyRosetta has adequately showcased what operations can be done in PyRosetta to better tune and analyse variants from RFdiffusion.


Footnote: ProteinMPNN helper as API functions

As mentioned, my designs had different lengths, and I wanted to create the definitions for ProteinMPNN within my own script, as it was reading from tar.gz archives, among other details. Briefly, I parsed the RFdiffusion skeletons thusly:

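import json
import re
from pathlib import Path
from typing import Dict, List

paths: List[Path] = ...  # RFdiffusion skeleton PDB files
chains_definitions_path: Path = ...  # output, JSONL
fixed_chains_path: Path = ...  # output, JSON
fixed_positions_path: Path = ...  # output, JSON

definitions = []
global_fixed_chains: Dict[str, List[List[str]]] = {}
global_fixed_positions = {}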

for path in paths:
    # ## Load data
    name = path.stem
    pdbblock = get_ATOM_only(path.read_text())
    # ## Parse definition (i.e. coordinates)
    definition = parse_PDBblock(pdbblock, name)
    definitions.append(definition)
    # ## Fixed chain
    fixed_chains = define_fixed_chains(definition, designed_chain_list='A') # chain A is designed
    global_fixed_chains[name] = fixed_chains
    # ## Fixed pos
    sequence = get_chainA_sequence(pdbblock)  # assuming the sequence to change is chain A
    # The skeleton will have stretches of glycines where the design was made:
    masked = re.sub(r'(G{3,})', lambda match: '_' * len(match.group(1)), sequence)
    # ProteinMPNN residue numbering is 1-based (cf. define_unfixed_positions), hence start=1
    designed_list = [[i for i, r in enumerate(masked, start=1) if r == '_']]
    fixed_positions = define_unfixed_positions(definition, ['A'], designed_list)
    global_fixed_positions[name] = fixed_positions

# write out definitions, global_fixed_chains, global_fixed_positions

with open(chains_definitions_path, 'w') as fh:  # only chains_definitions.jsonl is a JSONL
    for definition in definitions:
        fh.write(json.dumps(definition) + '\n')

with open(fixed_chains_path, 'w') as fh:
    json.dump(global_fixed_chains, fh)

with open(fixed_positions_path, 'w') as fh:
    json.dump(global_fixed_positions, fh)

Where the functions are:

"""
This is a minor refactoring of the original code, with the following changes:

`parse_PDBblock(pdbblock: str, name: str, chain_alphabet, ca_only=False)`
accepts a PDB block ``pdbblock`` and a name ``name`` and returns a dictionary that forms a line for the JSONL.

The code is mostly kept as was, with minor chances.
"""

import json
from typing import List, Sequence, Dict, Any
import numpy as np

alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
states = len(alpha_1)
alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
           'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']

aa_1_N = {a: n for n, a in enumerate(alpha_1)}
aa_3_N = {a: n for n, a in enumerate(alpha_3)}
aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}

def get_ATOM_only(pdbblock: str) -> str:
    """
    This gets all ATOM, regardless of name and chain
    """
    return '\n'.join([line for line in pdbblock.splitlines() if line.startswith('ATOM')])

three_to_one = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
    'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
    'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
}

def get_chainA_sequence(pdbblock: str) -> str:
    """One-letter sequence of chain A, read from its CA atoms."""
    sequence = ''
    residues_seen = set()
    for line in pdbblock.splitlines():
        # fixed PDB columns: atom name is line[12:16], chain ID is line[21]
        if line.startswith("ATOM") and line[12:16].strip() == 'CA' and line[21:22] == 'A':
            res_info = line[17:27]  # residue name, chain, number and insertion code, for uniqueness
            if res_info not in residues_seen:
                residues_seen.add(res_info)
                res_name = line[17:20].strip()
                sequence += three_to_one.get(res_name, '?')
    return sequence


def AA_to_N(x):
    # ["ARND"] -> [[0,1,2,3]]
    x = np.array(x)
    if x.ndim == 0:
        x = x[None]
    return [[aa_1_N.get(a, states - 1) for a in y] for y in x]


def N_to_AA(x):
    # [[0,1,2,3]] -> ["ARND"]
    x = np.array(x)
    if x.ndim == 1:
        x = x[None]
    return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]


def inner_parse_PDBblock(pdbblock, atoms=['N', 'CA', 'C'], chain=None) -> tuple:
    '''
    input:  pdbblock = PDB block (text contents of a PDB file)
            atoms = atoms to extract (optional)
            chain = chain ID to extract (optional, None for all chains)
    output: (coords of shape (length, len(atoms), 3), sequence)
    '''
    xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6
    for line in pdbblock.split("\n"):
        if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
            line = line.replace("HETATM", "ATOM  ")
            line = line.replace("MSE", "MET")

        if line[:4] == "ATOM":
            ch = line[21:22]
            if ch == chain or chain is None:
                atom = line[12:12 + 4].strip()
                resi = line[17:17 + 3]
                resn = line[22:22 + 5].strip()
                x, y, z = [float(line[i:(i + 8)]) for i in [30, 38, 46]]

                if resn[-1].isalpha():
                    resa, resn = resn[-1], int(resn[:-1]) - 1
                else:
                    resa, resn = "", int(resn) - 1
                #         resn = int(resn)
                if resn < min_resn:
                    min_resn = resn
                if resn > max_resn:
                    max_resn = resn
                if resn not in xyz:
                    xyz[resn] = {}
                if resa not in xyz[resn]:
                    xyz[resn][resa] = {}
                if resn not in seq:
                    seq[resn] = {}
                if resa not in seq[resn]:
                    seq[resn][resa] = resi

                if atom not in xyz[resn][resa]:
                    xyz[resn][resa][atom] = np.array([x, y, z])
    # ^^ end of xyz loop

    # convert to numpy arrays, fill in missing values
    seq_, xyz_ = [], []
    try:
        for resn in range(min_resn, max_resn + 1):
            if resn in seq:
                for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k], 20))
            else:
                seq_.append(20)
            if resn in xyz:
                for k in sorted(xyz[resn]):
                    for atom in atoms:
                        if atom in xyz[resn][k]:
                            xyz_.append(xyz[resn][k][atom])
                        else:
                            xyz_.append(np.full(3, np.nan))
            else:
                for atom in atoms: xyz_.append(np.full(3, np.nan))
        return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_))
    except TypeError:
        return 'no_chain', 'no_chain'


def get_chain_alphabet():
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                     'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
                     'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    return init_alphabet + extra_alphabet


def parse_PDBblock(pdbblock: str, name: str, ca_only=False):
    chain_alphabet = get_chain_alphabet()
    my_dict = {}
    s = 0
    concat_seq = ''
    for letter in chain_alphabet:
        if ca_only:
            sidechain_atoms = ['CA']
        else:
            sidechain_atoms = ['N', 'CA', 'C', 'O']
        xyz, seq = inner_parse_PDBblock(pdbblock=pdbblock, atoms=sidechain_atoms, chain=letter)
        if type(xyz) != str:
            concat_seq += seq[0]
            my_dict['seq_chain_' + letter] = seq[0]
            coords_dict_chain = {}
            if ca_only:
                coords_dict_chain['CA_chain_' + letter] = xyz.tolist()
            else:
                coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
                coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
                coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
                coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
            my_dict['coords_chain_' + letter] = coords_dict_chain
            s += 1
    my_dict['name'] = name
    my_dict['num_of_chains'] = s
    my_dict['seq'] = concat_seq
    if s < len(chain_alphabet):
        return my_dict
    else:
        raise Exception('Too many chains')


def define_fixed_chains(chains_definition: Dict[str, Any], designed_chain_list=('A',)):
    """
    The fixed chain definition file is a JSON dictionary of name to tuple/list of two: designed_chain_list and fixed_chain_list
    """
    all_chain_list = [item[-1:] for item in list(chains_definition) if item[:9] == 'seq_chain']  # ['A','B', 'C',...]
    fixed_chain_list = [letter for letter in all_chain_list if
                        letter not in designed_chain_list]  # fix/do not redesign these chains
    return list(designed_chain_list), fixed_chain_list


def define_global_fixed_chains(chains_definitions, global_designed_chain_list=('A',)):
    return {chains_definition['name']: define_fixed_chains(chains_definition, global_designed_chain_list)
            for chains_definition in chains_definitions}


def define_fixed_positions(chains_definition: Dict[str, Any],
                           designed_chain_list: Sequence[str],
                           fixed_list: Sequence[Sequence[int]]):
    all_chain_list = [item[-1:] for item in list(chains_definition) if item[:9] == 'seq_chain']
    fixed_position_dict = {}
    for i, chain in enumerate(designed_chain_list):
        fixed_position_dict[chain] = fixed_list[i]
    for chain in all_chain_list:
        if chain not in designed_chain_list:
            fixed_position_dict[chain] = []
    return fixed_position_dict


def define_unfixed_positions(chains_definition,
                             designed_chain_list: Sequence[str],
                             unfixed_list: Sequence[Sequence[int]]):
    """
    This will be an entry in a dictionary passed to ``chain_id_jsonl``
    (Misnomer: it's not a JSONL, but a JSON, only ``jsonl_path`` is)

    :param chains_definition:
    :param designed_chain_list:
    :param unfixed_list:
    :return:
    """
    all_chain_list = [item[-1:] for item in list(chains_definition) if item[:9] == 'seq_chain']
    fixed_position_dict = {}
    for chain in all_chain_list:
        seq_length = len(chains_definition[f'seq_chain_{chain}'])
        all_residue_list = (np.arange(seq_length) + 1).tolist()
        if chain not in designed_chain_list:
            fixed_position_dict[chain] = all_residue_list
        else:
            idx = np.argwhere(np.array(designed_chain_list) == chain)[0][0]
            fixed_position_dict[chain] = sorted(set(all_residue_list) - set(unfixed_list[idx]))
    return fixed_position_dict
