#!/usr/bin/env python3

import argparse
import sys
from collections import defaultdict


GAP_SYMBOLS = {"-", "."}


def read_alignment(filename):
    """
    Считывает FASTA-выравнивание.

    Возвращает:
    sequences: dict[str, str]
        Ключ — ID последовательности без символа '>'.
        Значение — строка выравнивания.
    """
    sequences = {}
    current_id = None

    try:
        with open(filename, "r", encoding="utf-8-sig") as file:
            for line_number, line in enumerate(file, start=1):
                line = line.strip()

                if not line:
                    continue

                if line.startswith(">"):
                    current_id = line[1:].strip()

                    if not current_id:
                        raise ValueError(
                            f"{filename}: пустой идентификатор в строке {line_number}"
                        )

                    if current_id in sequences:
                        raise ValueError(
                            f"{filename}: повторяющийся идентификатор '{current_id}'"
                        )

                    sequences[current_id] = []
                else:
                    if current_id is None:
                        raise ValueError(
                            f"{filename}: найдена последовательность до первого FASTA-заголовка "
                            f"в строке {line_number}"
                        )

                    sequences[current_id].append(line)

    except FileNotFoundError:
        raise FileNotFoundError(f"Файл не найден: {filename}")

    if not sequences:
        raise ValueError(f"{filename}: файл пустой или не содержит FASTA-записей")

    sequences = {
        seq_id: "".join(parts)
        for seq_id, parts in sequences.items()
    }

    return sequences


def validate_alignment(sequences, filename):
    """
    Проверяет, что все последовательности внутри одного выравнивания
    имеют одинаковую длину.
    """
    lengths = {
        seq_id: len(seq)
        for seq_id, seq in sequences.items()
    }

    unique_lengths = set(lengths.values())

    if len(unique_lengths) != 1:
        message = [f"{filename}: последовательности имеют разную длину:"]
        for seq_id, length in lengths.items():
            message.append(f"  {seq_id}: {length}")
        raise ValueError("\n".join(message))

    alignment_length = next(iter(unique_lengths))

    if alignment_length == 0:
        raise ValueError(f"{filename}: длина выравнивания равна 0")

    return alignment_length


def ungap(sequence):
    """
    Возвращает последовательность без гэпов.
    """
    return "".join(
        char for char in sequence
        if char not in GAP_SYMBOLS
    )


def validate_same_sequences(aln1, aln2):
    """
    Проверяет, что два выравнивания содержат одни и те же ID
    и одни и те же последовательности без гэпов.
    """
    ids1 = set(aln1.keys())
    ids2 = set(aln2.keys())

    if ids1 != ids2:
        missing_in_2 = sorted(ids1 - ids2)
        missing_in_1 = sorted(ids2 - ids1)

        message = ["Наборы ID в двух файлах не совпадают."]

        if missing_in_2:
            message.append(
                "ID, отсутствующие во втором файле: " + ", ".join(missing_in_2)
            )

        if missing_in_1:
            message.append(
                "ID, отсутствующие в первом файле: " + ", ".join(missing_in_1)
            )

        raise ValueError("\n".join(message))

    for seq_id in sorted(ids1):
        seq1 = ungap(aln1[seq_id])
        seq2 = ungap(aln2[seq_id])

        if seq1 != seq2:
            raise ValueError(
                f"Последовательность '{seq_id}' без гэпов отличается в двух файлах.\n"
                f"Длина в первом файле: {len(seq1)}\n"
                f"Длина во втором файле: {len(seq2)}"
            )


def alignment_to_position_matrix(alignment, ids):
    """
    Преобразует выравнивание в матрицу позиций.

    Для каждой последовательности:
    - остатки получают номера 1, 2, 3, ...
    - гэпы обозначаются '-'

    Пример:
    A-CG -> [1, "-", 2, 3]
    """
    position_matrix = {}

    for seq_id in ids:
        sequence = alignment[seq_id]
        position = 1
        encoded_sequence = []

        for char in sequence:
            if char in GAP_SYMBOLS:
                encoded_sequence.append("-")
            else:
                encoded_sequence.append(position)
                position += 1

        position_matrix[seq_id] = encoded_sequence

    return position_matrix


def get_columns(position_matrix, ids, alignment_length):
    """
    Строит список колонок из матрицы позиций.

    Каждая колонка представляется кортежем.
    Например:
    (1, "-", 5)
    """
    columns = []

    for column_index in range(alignment_length):
        column = tuple(
            position_matrix[seq_id][column_index]
            for seq_id in ids
        )
        columns.append(column)

    return columns


def compare_columns(columns1, columns2):
    """
    Сравнивает колонки двух выравниваний.

    Возвращает список пар (i, j), где:
    i — номер колонки в первом выравнивании,
    j — номер соответствующей колонки во втором выравнивании.

    Полностью гэповые колонки игнорируются.
    Нумерация колонок начинается с 1.
    """
    columns2_to_indices = defaultdict(list)

    for index, column in enumerate(columns2, start=1):
        if not is_all_gap_column(column):
            columns2_to_indices[column].append(index)

    result = []

    for i, column in enumerate(columns1, start=1):
        if is_all_gap_column(column):
            continue

        if column in columns2_to_indices:
            for j in columns2_to_indices[column]:
                result.append((i, j))

    return result

def is_all_gap_column(column):
    return all(value == "-" for value in column)

def find_blocks(equal_columns):
    """
    Находит блоки подряд идущих одинаково выровненных колонок.

    Блоком считается серия пар:
    (i, j), (i + 1, j + 1), (i + 2, j + 2), ...

    Возвращает:
    blocks:
        список блоков, каждый блок — список пар (i, j)
    single_columns:
        совпадающие колонки, не входящие в блоки длины >= 2
    """
    if not equal_columns:
        return [], []

    equal_columns = sorted(equal_columns)

    blocks = []
    current_block = [equal_columns[0]]

    for current in equal_columns[1:]:
        previous = current_block[-1]

        if current[0] == previous[0] + 1 and current[1] == previous[1] + 1:
            current_block.append(current)
        else:
            if len(current_block) >= 2:
                blocks.append(current_block)
            current_block = [current]

    if len(current_block) >= 2:
        blocks.append(current_block)

    elements_in_blocks = {
        pair
        for block in blocks
        for pair in block
    }

    single_columns = [
        pair
        for pair in equal_columns
        if pair not in elements_in_blocks
    ]

    return blocks, single_columns


def format_blocks(blocks):
    """
    Преобразует блоки в удобный для вывода формат.
    """
    formatted = []

    for block in blocks:
        formatted.append({
            "f1_range": (block[0][0], block[-1][0]),
            "f2_range": (block[0][1], block[-1][1]),
            "length": len(block)
        })

    return formatted

def write_output(
    output_file,
    equal_columns,
    aln1_length,
    aln2_length,
    percent1,
    percent2,
    formatted_blocks,
    single_columns
):
    """
    Записывает результат сравнения в файл в формате, удобном для отчёта.
    """
    with open(output_file, "w", encoding="utf-8") as file:
        file.write("# Сравнение двух выравниваний\n\n")

        file.write("## Статистика\n")
        file.write(f"Длина первого выравнивания: {aln1_length}\n")
        file.write(f"Длина второго выравнивания: {aln2_length}\n")
        file.write(f"Количество одинаково выровненных колонок: {len(equal_columns)}\n")
        file.write(
            f"Процент одинаково выровненных колонок от длины первого выравнивания: "
            f"{percent1:.2f}%\n"
        )
        file.write(
            f"Процент одинаково выровненных колонок от длины второго выравнивания: "
            f"{percent2:.2f}%\n"
        )

        file.write("\n## Блоки одинаково выровненных колонок\n")
        file.write("# Формат: (s1,f1)=(s2,f2), длина блока\n")

        if formatted_blocks:
            for block in formatted_blocks:
                s1, f1 = block["f1_range"]
                s2, f2 = block["f2_range"]
                length = block["length"]
                file.write(f"({s1},{f1})=({s2},{f2}), длина {length}\n")
        else:
            file.write("Блоков длиной >= 2 не найдено.\n")

        file.write("\n## Одинаково выровненные колонки, не входящие в блоки\n")
        file.write("# Формат: колонка_в_первом_выравнивании = колонка_во_втором_выравнивании\n")

        if single_columns:
            for i, j in single_columns:
                file.write(f"{i}={j}\n")
        else:
            file.write("Нет одиночных одинаково выровненных колонок.\n")

        file.write("\n## Полный список одинаково выровненных колонок\n")
        file.write("# Формат: колонка_в_первом_выравнивании колонка_во_втором_выравнивании\n")

        for i, j in equal_columns:
            file.write(f"{i}\t{j}\n")

def compare_alignments(file1, file2, output_file=None):
    aln1 = read_alignment(file1)
    aln2 = read_alignment(file2)

    aln1_length = validate_alignment(aln1, file1)
    aln2_length = validate_alignment(aln2, file2)

    validate_same_sequences(aln1, aln2)

    ids = sorted(aln1.keys())

    matrix1 = alignment_to_position_matrix(aln1, ids)
    matrix2 = alignment_to_position_matrix(aln2, ids)

    columns1 = get_columns(matrix1, ids, aln1_length)
    columns2 = get_columns(matrix2, ids, aln2_length)

    equal_columns = compare_columns(columns1, columns2)

    blocks, single_columns = find_blocks(equal_columns)
    formatted_blocks = format_blocks(blocks)

    percent1 = len(equal_columns) / aln1_length * 100
    percent2 = len(equal_columns) / aln2_length * 100

    print("Сравнение завершено.")
    print(f"Длина первого выравнивания: {aln1_length}")
    print(f"Длина второго выравнивания: {aln2_length}")
    print(f"Количество одинаково выровненных колонок: {len(equal_columns)}")
    print(
        "Процент одинаково выровненных колонок от длины первого выравнивания: "
        f"{percent1:.2f}%"
    )
    print(
        "Процент одинаково выровненных колонок от длины второго выравнивания: "
        f"{percent2:.2f}%"
    )
    print(f"Список одинаково выровненных колонок: {equal_columns}")
    print(f"Список блоков одинаково выровненных колонок: {formatted_blocks}")
    print(
        "Список одинаково выровненных колонок, не входящих в блоки: "
        f"{single_columns}"
    )

    if output_file:
        write_output(
            output_file=output_file,
            equal_columns=equal_columns,
            aln1_length=aln1_length,
            aln2_length=aln2_length,
            percent1=percent1,
            percent2=percent2,
            formatted_blocks=formatted_blocks,
            single_columns=single_columns
        )
        print(f"Результат сохранён в файл: {output_file}")


def main():
    parser = argparse.ArgumentParser(
        description="""
Программа для сравнения двух разных выравниваний одних и тех же последовательностей.

Алгоритм:
1. Для каждой последовательности в выравнивании негэповые символы нумеруются
   по порядку: 1, 2, 3, ...
2. Гэпы обозначаются символом '-'.
3. Каждая колонка выравнивания представляется набором позиций остатков
   во всех последовательностях.
4. Колонки двух выравниваний считаются одинаково выровненными, если их
   такие представления полностью совпадают.

Требования к входным файлам:
1. Формат входных файлов: FASTA.
2. В обоих файлах должны быть одинаковые идентификаторы последовательностей.
3. Последовательности с одинаковыми ID в двух файлах должны совпадать
   после удаления гэпов.
4. Внутри каждого файла все строки выравнивания должны иметь одинаковую длину.
5. Гэпами считаются символы '-' и '.'.
6. Все остальные символы считаются символами последовательности.
""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Пример запуска:

  python3 compare_alignments.py aln1.fasta aln2.fasta

С сохранением результата в файл:

  python3 compare_alignments.py aln1.fasta aln2.fasta -o result.txt

Формат выходного файла:
1. Сначала выводится список одинаково выровненных колонок.
   Каждая строка имеет формат:

      i    j

   где:
   i — номер колонки в первом выравнивании,
   j — номер соответствующей колонки во втором выравнивании.

2. Далее выводится статистика:
   - длина первого выравнивания;
   - длина второго выравнивания;
   - процент одинаково выровненных колонок от длины первого выравнивания;
   - процент одинаково выровненных колонок от длины второго выравнивания;
   - список блоков одинаково выровненных колонок;
   - список одинаково выровненных колонок, не входящих в блоки.
"""
    )

    parser.add_argument(
        "file1",
        help="Путь к первому FASTA-файлу с выравниванием"
    )

    parser.add_argument(
        "file2",
        help="Путь ко второму FASTA-файлу с выравниванием"
    )

    parser.add_argument(
        "-o",
        "--output",
        help="Путь к выходному файлу"
    )

    args = parser.parse_args()

    try:
        compare_alignments(args.file1, args.file2, args.output)
    except Exception as error:
        print(f"Ошибка: {error}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
