#!/usr/bin/env python3
"""
compare_aln.py — сравнение двух множественных выравниваний (аналог VerAlign)

Использование:
    python3 compare_aln.py aln1.fasta aln2.fasta [label1] [label2]

Входные данные:
    Два FASTA файла с выравниваниями одних и тех же последовательностей.
    Заголовки вида >sp|ACC|ID_SPECIES ..., >PDB:1x03:A ... или >1x03_A_... .
    Длины выравненных последовательностей в каждом файле должны совпадать.

Выходные данные (stdout):
    - Длины выравниваний
    - Число и % совпадающих колонок
    - Список блоков (s1,f1)=(s2,f2) с длиной >= 2
    - Одиночные совпадающие колонки вне блоков
    - Несовпадающие участки в первом выравнивании
"""

import sys


def parse_fasta(filepath):
    """
    Читает FASTA файл с выравниванием.
    Возвращает dict {short_id: gapped_sequence}.
    Поддерживает форматы заголовков:
      - UniProt: >sp|ACC|ID_SPECIES ... -> ID
      - PDBeFold: >PDB:1x03:A ...      -> 1x03
      - MUSCLE:   >1x03_A_Name ...     -> 1x03
    """
    seqs = {}
    current = None
    with open(filepath) as f:
        for line in f:
            line = line.rstrip()
            if line.startswith('>'):
                header = line[1:].split()[0]
                parts = line[1:].split('|')
                if len(parts) >= 3:
                    # UniProt: sp|ACC|ID_SPECIES
                    current = parts[2].split()[0]
                elif header.startswith('PDB:'):
                    # PDBeFold: PDB:1x03:A
                    current = header.split(':')[1].lower()
                else:
                    # MUSCLE/другие: 1x03_A_Name -> берём до первого _
                    current = header.split('_')[0].lower()
                seqs[current] = ''
            elif current:
                seqs[current] += line.replace(' ', '')
    return seqs


def compare_alignments(aln1, aln2, label1='Aln1', label2='Aln2'):
    """
    Сравнивает два выравнивания одних и тех же последовательностей.

    Алгоритм:
      Для каждой колонки c1 в aln1 находим остатки всех последовательностей.
      Смотрим, в какую колонку c2 в aln2 попадают те же остатки.
      Если для всех последовательностей это одна и та же колонка c2 —
      пара (c1, c2) считается совпадающей.

    Возвращает: (blocks, singles, matching)
      blocks  — список (s1, f1, s2, f2, length) блоков длиной >= 2
      singles — список (c1, c2) одиночных совпадений вне блоков
      matching — полный список совпадающих пар (c1, c2), 1-based
    """
    seqs = sorted(set(aln1.keys()) & set(aln2.keys()))
    if not seqs:
        print('ОШИБКА: нет общих последовательностей между файлами.')
        print(f'  {label1}: {sorted(aln1.keys())}')
        print(f'  {label2}: {sorted(aln2.keys())}')
        return [], [], []

    ncols1 = len(list(aln1.values())[0])
    ncols2 = len(list(aln2.values())[0])

    # остаток -> колонка в aln2
    res_to_col2 = {}
    for s in seqs:
        res_to_col2[s] = {}
        ri = 0
        for col, aa in enumerate(aln2[s]):
            if aa != '-':
                res_to_col2[s][ri] = col
                ri += 1

    # колонка -> индекс остатка в aln1
    col_to_res1 = {}
    for s in seqs:
        col_to_res1[s] = []
        ri = 0
        for aa in aln1[s]:
            col_to_res1[s].append(None if aa == '-' else ri)
            if aa != '-':
                ri += 1

    # колонка -> индекс остатка в aln2
    col_to_res2 = {}
    for s in seqs:
        col_to_res2[s] = []
        ri = 0
        for aa in aln2[s]:
            col_to_res2[s].append(None if aa == '-' else ri)
            if aa != '-':
                ri += 1

    # Поиск совпадающих пар колонок
    matching = []
    for c1 in range(ncols1):
        assignments = {s: col_to_res1[s][c1] for s in seqs
                       if col_to_res1[s][c1] is not None}
        if not assignments:
            continue
        c2_candidates = set()
        for s, ri in assignments.items():
            c2_candidates.add(res_to_col2[s].get(ri, None))
        if len(c2_candidates) != 1 or None in c2_candidates:
            continue
        c2 = list(c2_candidates)[0]
        # проверка в обратную сторону
        assignments2 = {s: col_to_res2[s][c2] for s in seqs
                        if col_to_res2[s][c2] is not None}
        if assignments2 == assignments:
            matching.append((c1 + 1, c2 + 1))

    # Блоки consecutive совпадений
    blocks = []
    if matching:
        bs1, bs2 = matching[0]
        ps1, ps2 = matching[0]
        for c1, c2 in matching[1:]:
            if c1 == ps1 + 1 and c2 == ps2 + 1:
                ps1, ps2 = c1, c2
            else:
                if ps1 - bs1 + 1 >= 2:
                    blocks.append((bs1, ps1, bs2, ps2, ps1 - bs1 + 1))
                bs1, bs2, ps1, ps2 = c1, c2, c1, c2
        if ps1 - bs1 + 1 >= 2:
            blocks.append((bs1, ps1, bs2, ps2, ps1 - bs1 + 1))

    in_block = set()
    for b in blocks:
        for i in range(b[0], b[1] + 1):
            in_block.add(i)
    singles = [(c1, c2) for c1, c2 in matching if c1 not in in_block]

    # ── Вывод ──
    print(f"\n{'='*60}")
    print(f"Сравнение: {label1} vs {label2}")
    print(f"Общие последовательности ({len(seqs)}): {', '.join(seqs)}")
    print(f"Длина {label1}: {ncols1}  |  Длина {label2}: {ncols2}")
    print(f"Совпадающих колонок: {len(matching)}")
    print(f"  % от {label1}: {100*len(matching)/ncols1:.1f}%")
    print(f"  % от {label2}: {100*len(matching)/ncols2:.1f}%")

    print(f"\nБлоки (длина >= 2), по убыванию длины:")
    if blocks:
        print(f"  {'(s1,f1)':>14} = {'(s2,f2)':>14}  длина")
        for b in sorted(blocks, key=lambda x: -x[4]):
            print(f"  ({b[0]},{b[1]}) = ({b[2]},{b[3]})  {b[4]}")
    else:
        print("  блоков нет")

    print(f"\nВсего блоков: {len(blocks)}")
    print(f"Одиночных совпадений вне блоков: {len(singles)}")
    if singles:
        print("  " + ", ".join(f"({c1},{c2})" for c1, c2 in singles))

    blocks_sorted = sorted(blocks, key=lambda x: x[0])
    print(f"\nНесовпадающие участки в {label1}:")
    prev = 0
    for b in blocks_sorted:
        if b[0] > prev + 1:
            print(f"  {prev+1}-{b[0]-1} (длина {b[0]-1-prev})")
        prev = b[1]
    if prev < ncols1:
        print(f"  {prev+1}-{ncols1} (длина {ncols1-prev})")

    return blocks, singles, matching


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        path1, path2 = sys.argv[1], sys.argv[2]
        lab1 = sys.argv[3] if len(sys.argv) > 3 else path1
        lab2 = sys.argv[4] if len(sys.argv) > 4 else path2
        A = parse_fasta(path1)
        B = parse_fasta(path2)
        compare_alignments(A, B, label1=lab1, label2=lab2)
    else:
        #без аргументов берет следующие знанчения:
        A = parse_fasta('prakt-D-A-muscle.fasta')
        B = parse_fasta('prakt-D-B-mafft.fasta')
        C = parse_fasta('prakt-D-C-tcoffee.fasta')
        compare_alignments(A, B, label1='A(MUSCLE)', label2='B(MAFFT)')
        compare_alignments(A, C, label1='A(MUSCLE)', label2='C(T-Coffee)')