# save as dna_protein_contacts.py and run inside PyMOL with: run dna_protein_contacts.py
# or paste into PyMOL command line.
from pymol import cmd
import math
from collections import defaultdict

# --- CONFIG: объект и расстояния ---
OBJ_NAME = "1TRO"   # <- поменяйте, если ваш объект называется иначе (или загрузите XXXX.pdb как 1TRO)
POLAR_DIST = 3.5     # Å for polar-polar
NONPOLAR_DIST = 4.5  # Å for nonpolar-nonpolar

# --- определения атомных типов (по вашему описанию) ---
POLAR_ATOMS = set(['N','O'])
NONPOLAR_ATOMS = set(['C','P','S'])

# --- атомные имена для частей нуклеотида ---
# имена атомов сахара (вариации с и без штрихов)
SUGAR_ATOM_NAMES = set([
    "C1'", "C2'", "C3'", "C4'", "C5'", "O4'", "O3'", "O2'", "C5*", "C4*", "C3*", "C2*", "C1*", "O4*", "O3*", "O2*",
    "C1", "C2", "C3", "C4", "C5", "O4", "O3", "O2"  # на случай нестандартных имён
])

# атомы фосфатной группы
PHOSPHATE_ATOM_NAMES = set(["P","OP1","OP2","O1P","O2P","O5'","O5*"])

# Для оснований — остальные атомы нуклеотида не отнесённые к сахару/фосфату считаются атомами основания
# (это удобная и часто работоспособная классификация)

# --- эвристика определения большой/малой бороздки для атома основания ---
# Простая эвристика: для пуринов и пиримидинов существуют атомы, преимущественно выходящие в большую бороздку.
# Ниже — словари атомных имён, которые мы относим к большой / малой стороне.
# Это НЕ идеально, но даёт разумное приближение для подсчёта контактов по сторонам бороздок.
MAJOR_GROOVE_ATOMS = set([
    # пурины (A,G) типичные атомы на стороне большой бороздки
    "N7","C8","O6","N6","C5","C4","N3","C2",  # часть из них часто обращена в большую бороздку
    # пиримидины (C,T) — некоторые атомы более направлены в большую бороздку
    "O4","C5","C6","N4"
])
MINOR_GROOVE_ATOMS = set([
    # атомы, чаще обращённые в малую бороздку
    "N2","N3","O2","N1"
])

# --- вспомогательные функции ---
def dist(a, b):
    return math.sqrt((a[0]-b[0])**2 + (a[1]-b[1])**2 + (a[2]-b[2])**2)

def atom_key(atom):
    # уникальный ключ для атома
    return (atom.chain, atom.resi, atom.resn, atom.name, atom.index)

# --- получаем модель объекта ---
model = cmd.get_model(OBJ_NAME)

# разделим атомы на две группы: белок и ДНК
protein_atoms = []
dna_atoms = []

for a in model.atom:
    # a.segi, a.chain, a.resn, a.resi, a.name, a.index доступны
    # определить «белок» через полное имя остатка: аминокислоты — трёхбуквенные коды (ALA, GLY и т.д.)
    resn = a.resn.strip().upper()
    # стандартная проверка: если residue name одно из стандартных нуклеотидных, считаем ДНК
    dna_names = set(["DA","DT","DG","DC","A","T","G","C","DI","DU"])  # добавлены вариации
    if resn in dna_names:
        dna_atoms.append(a)
    else:
        # если residue name — аминокислота (трёхбуквенный), либо hetero flag? Мы считаем всё остальное белком.
        protein_atoms.append(a)

# --- инициализация счётчиков ---
# словарь: counters[part]['polar'], ['nonpolar'], ['total']
parts = ["sugar","phosphate","base_major","base_minor"]
counters = {p: {'polar':0, 'nonpolar':0, 'total':0} for p in parts}
# также посчитаем все базовые контакты в отдельную категорию base_unknown, если не уверены
counters['base_unknown'] = {'polar':0, 'nonpolar':0, 'total':0}

# список найденных контактов (для печати/проверки)
contact_list = []

# Пройдём по всем парам белок-ДНК (можно ускорять пространственным индексом, но для обычных структур это нормально)
for pa in protein_atoms:
    p_elem = pa.symbol.strip().upper() if hasattr(pa, 'symbol') else pa.name[0].upper()
    for da in dna_atoms:
        d_elem = da.symbol.strip().upper() if hasattr(da, 'symbol') else da.name[0].upper()
        # классификация по полярности
        is_polar_pair = (p_elem in POLAR_ATOMS) and (d_elem in POLAR_ATOMS)
        is_nonpolar_pair = (p_elem in NONPOLAR_ATOMS) and (d_elem in NONPOLAR_ATOMS)
        if not (is_polar_pair or is_nonpolar_pair):
            continue  # по правилам нас интересуют только пары поляр-поляр или неполяр-неполяр

        # рассчитаем расстояние
        pcoord = (pa.coord[0], pa.coord[1], pa.coord[2])
        dcoord = (da.coord[0], da.coord[1], da.coord[2])
        r = dist(pcoord, dcoord)

        # применим пороги
        if is_polar_pair and r <= POLAR_DIST:
            kind = 'polar'
        elif is_nonpolar_pair and r <= NONPOLAR_DIST:
            kind = 'nonpolar'
        else:
            continue

        # определить к какой части нуклеотида принадлежит атом да
        dname = da.name.strip()
        part = None
        if dname in PHOSPHATE_ATOM_NAMES:
            part = 'phosphate'
        elif dname in SUGAR_ATOM_NAMES:
            part = 'sugar'
        else:
            # считаем атомом основания — определим большую/малую бороздку по имени атома (эвристика)
            uname = dname.replace("'", "").upper()  # нормализуем имя
            if uname in MAJOR_GROOVE_ATOMS:
                part = 'base_major'
            elif uname in MINOR_GROOVE_ATOMS:
                part = 'base_minor'
            else:
                # если не уверены — пометим как base_unknown
                part = 'base_unknown'

        counters.setdefault(part, {'polar':0,'nonpolar':0,'total':0})
        counters[part][kind] += 1
        counters[part]['total'] += 1

        contact_list.append({
            'prot_atom': (pa.chain, pa.resi, pa.resn, pa.name),
            'dna_atom': (da.chain, da.resi, da.resn, da.name),
            'dna_part': part,
            'kind': kind,
            'dist': r
        })

# --- печать результатов в форме таблицы ---
print("\nКонтакты атомов белка с ДНК в структуре: {}\n".format(OBJ_NAME))
print("Таблица. Контакты разного типа в комплексе {} (числа — количество пар атомов)\n".format(OBJ_NAME))
header = "{:<40s}{:>12s}{:>12s}{:>12s}".format("Контакты атомов белка с", "Полярные", "Неполярные", "Всего")
print(header)
print("-"*len(header))

rows = [
    ("остатками 2'-дезоксирибозы", "sugar"),
    ("остатками фосфорной кислоты", "phosphate"),
    ("остатками азотистых оснований со стороны большой бороздки", "base_major"),
    ("остатками азотистых оснований со стороны малой бороздки", "base_minor"),
    ("остатками азотистых оснований (неуточнённые по бороздке)", "base_unknown")
]

for label, key in rows:
    pol = counters.get(key,{}).get('polar',0)
    nonpol = counters.get(key,{}).get('nonpolar',0)
    tot = counters.get(key,{}).get('total',0)
    print("{:<40s}{:>12d}{:>12d}{:>12d}".format(label, pol, nonpol, tot))

# --- опционально: вывести N первых контактов для проверки ---
print("\nПримеры контактов (первые 50):")
for i,c in enumerate(contact_list[:50]):
    pa = c['prot_atom']; da = c['dna_atom']
    print("{:2d}: Prot {}:{}({}) atom {}  ---  DNA {}:{}({}) atom {}  | part={} kind={} dist={:.2f}Å".format(
        i+1, pa[0], pa[1], pa[2], pa[3], da[0], da[1], da[2], da[3], c['dna_part'], c['kind'], c['dist']
    ))

print("\nВсего найдено контактов (пар атомов):", len(contact_list))
print("Скрипт завершил работу. Если хотите, могу сохранить контакты в CSV для анализа.")
