#!/usr/bin/env python3
import argparse
import sys

def compare_alignments(file1, file2, output_file=None):
    with open(file1, "r") as f1:
        f1_dct = {}
        current_key = None
        for line in f1:
            line = line.strip()
            if line.startswith(">"):
                current_key = line
                f1_dct[current_key] = []
                counter = 1
            elif current_key:
                for c in line:
                    if c.isalpha():
                        f1_dct[current_key].append(counter)
                        counter += 1
                    else:
                        f1_dct[current_key].append("-")

    with open(file2, "r") as f2:
        f2_dct = {}
        current_key = None
        for line in f2:
            line = line.strip()
            if line.startswith(">"):
                current_key = line
                f2_dct[current_key] = []
                counter = 1
            elif current_key:
                for c in line:
                    if c.isalpha():
                        f2_dct[current_key].append(counter)
                        counter += 1
                    else:
                        f2_dct[current_key].append("-")

    answer = []
    f1_keys = set(f1_dct.keys())
    f2_keys = set(f2_dct.keys())
    if f1_keys != f2_keys:
        missing_in_f2 = f1_keys - f2_keys
        missing_in_f1 = f2_keys - f1_keys
        print(f"Ошибка: Наборы ID не совпадают!")
        if missing_in_f2: print(f"Нет во втором файле: {missing_in_f2}")
        if missing_in_f1: print(f"Нет в первом файле: {missing_in_f1}")
        sys.exit(1)
    key_list = sorted([key for key in f1_dct])
    f1_columns = []
    f1_len = len(list(f1_dct.values())[0])
    for ind in range(f1_len):
        f1_columns.append([f1_dct[key][ind] for key in key_list])
    f1_aln_len = len(f1_columns)

    f2_columns = []
    f2_len = len(list(f2_dct.values())[0])
    for ind in range(f2_len):
        f2_columns.append([f2_dct[key][ind] for key in key_list])
    f2_aln_len = len(f2_columns)

    for column in range(len(f1_columns)):
        if f1_columns[column] in f2_columns:
            answer.append(tuple((column + 1, f2_columns.index(f1_columns[column]) + 1)))

    blocks = []
    if answer:
        current_block = [answer[0]]
        for i in range(1, len(answer)):
            prev = answer[i - 1]
            curr = answer[i]

            if curr[0] == prev[0] + 1 and curr[1] == prev[1] + 1:
                current_block.append(curr)
            else:
                if len(current_block) > 1:
                    blocks.append(current_block)
                current_block = [curr]
        if len(current_block) > 1:
            blocks.append(current_block)
    formatted_blocks = []
    for b in blocks:
        formatted_blocks.append({
            "f1_range": (b[0][0], b[-1][0]),
            "f2_range": (b[0][1], b[-1][1]),
            "length": len(b)
        })

    elements_in_blocks = set()
    for b in blocks:
        for tuple_item in b:
            elements_in_blocks.add(tuple_item)
    answer_set = set(answer)
    no_blocks = sorted(list(answer_set - elements_in_blocks))

    if output_file:
        with open(output_file, "w") as f:
            f.write(str(answer) + '\n' + "Длина первого выравнивания: " + str(f1_aln_len) + '\n' +
                    "Длина второго выравнивания: " + str(f2_aln_len) + '\n' +
                    f"Процент одинаково выравненных колонок от длины первого выравнивания: {round((len(answer)/f1_aln_len)*100, 2)}%" + '\n'
                    f"Процент одинаково выравненных колонок от длины второго выравнивания: {round((len(answer)/f2_aln_len)*100, 2)}%" + '\n'
                    f'Список блоков одинаково выровненных колонок - {formatted_blocks}' + '\n'
                    f'Список одинаково выровненных колонок, не входящих в блоки - {no_blocks}')
        print(f"Результат сохранен в {output_file}")
    else:
        print(answer)
        print("Длина первого выравнивания:", f1_aln_len)
        print("Длина второго выравнивания:", f2_aln_len)
        print(f"Процент одинаково выравненных колонок от длины первого выравнивания: {round((len(answer)/f1_aln_len)*100, 2)}%")
        print(f"Процент одинаково выравненных колонок от длины второго выравнивания: {round((len(answer)/f2_aln_len)*100, 2)}%")
        print(f'Список блоков одинаково выровненных колонок - {formatted_blocks}')
        print(f'Список одинаково выровненных колонок, не входящих в блоки - {no_blocks}')


parser = argparse.ArgumentParser(
    description="""Скрипт для сравнения выравниваний.

ТРЕБОВАНИЯ К ВХОДНЫМ ФАЙЛАМ:
1. Формат: FASTA.
2. Идентификаторы: Названия последовательностей в обоих файлах должны быть идентичными.
3. Длина: Все последовательности внутри ОДНОГО файла должны иметь одинаковую длину.
4. Символы: Буквы считаются значимыми символами, остальные — гэпами (пробелами).""",

    formatter_class=argparse.RawDescriptionHelpFormatter,
    epilog="""
ФОРМАТ ВЫХОДНОГО ФАЙЛА:
Список кортежей, где (x, y):
  x — номер колонки в первом файле
  y — номер соответствующей колонки во втором файле
Пример: [(1, 1), (2, 5), (3, 6)]
Длина первого выравнивания: число
Длина второго выравнивания: число
Процент одинаково выровненных колонок от длины первого выравнивания: число %
Процент одинаково выровненных колонок от длины второго выравнивания: число %
Список блоков одинаково выровненных колонок. В нем выдача имеет вид {‘f1_range’: (x1, y1), ‘f2_range’: (x2, y2)}, где f1_range - первый алгоритм выравнивания, f2_range - второй алгоритм выравнивания, x - номер столбца начала блока, y - номер столбца конца блока.
Список одинаково выровненных колонок, не входящих в блоки, он имеет вид - [(x, y)], где x - координата столбца первого алгоритма выравнивания, y - координата столбца второго алгоритма выравнивания



ПРИМЕЧАНИЕ:
Если идентификаторы в файлах не совпадут, программа выдаст KeyError.
"""
)

parser.add_argument("file1", help="Путь к первому файлу с выравниванием")
parser.add_argument("file2", help="Путь к второму файлу с выравниванием")
parser.add_argument("-o", "--output", help="Путь к файлу для сохранения результата")

if len(sys.argv) == 1:
    parser.print_help()
    sys.exit(1)

args = parser.parse_args()

compare_alignments(args.file1, args.file2, args.output)

