#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import argparse
import sys
from matplotlib.ticker import FuncFormatter

def format_ticks(value, pos):
    """Форматирование меток: K для тысяч, M для миллионов."""
    if value >= 1e6:
        return f'{value/1e6:.1f}M'
    elif value >= 1e3:
        return f'{int(value/1e3)}K'
    else:
        return str(int(value))

def parse_blast_outfmt7(filename):
    """
    Парсит файл BLAST выдачи в формате -outfmt 7.
    Возвращает:
        query_id : str
        subject_id : str
        hits : list of tuples (qstart, qend, sstart, send)
        max_q : int (максимальная координата запроса)
        max_s : int (максимальная координата субъекта)
    """
    query_id = None
    subject_id = None
    hits = []
    qstarts = []
    qends = []
    sstarts = []
    sends = []

    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('# Query:'):
                query_id = line.split(':', 1)[1].strip()
            elif line.startswith('# Subject:'):
                subject_id = line.split(':', 1)[1].strip()
            elif not line.startswith('#'):
                fields = line.split('\t')
                if len(fields) < 12:
                    continue
                try:
                    qstart = int(fields[6])
                    qend = int(fields[7])
                    sstart = int(fields[8])
                    send = int(fields[9])
                except ValueError:
                    continue
                hits.append((qstart, qend, sstart, send))
                qstarts.append(qstart)
                qends.append(qend)
                sstarts.append(sstart)
                sends.append(send)

    if not hits:
        sys.exit("Ошибка: не найдено ни одного выравнивания в файле.")

    max_q = max(max(qstarts), max(qends))
    max_s = max(max(sstarts), max(sends))
    return query_id, subject_id, hits, max_q, max_s

def setup_axes(ax, max_x, max_y, xlabel, ylabel,
               title_fontsize=16,
               label_fontsize=14,
               tick_fontsize=12):
    """
    Настраивает оси: пределы, деления, подписи, фон, формат меток.
    Параметры:
        title_fontsize - размер шрифта заголовка
        label_fontsize - размер подписей осей
        tick_fontsize  - размер меток делений
    """
    ax.set_facecolor('#e8f0e8')  # зеленоватый фон

    ax.set_xlim(0, max_x)
    ax.set_ylim(0, max_y)

    ax.set_xlabel(xlabel, fontsize=label_fontsize)
    ax.set_ylabel(ylabel, rotation=90, labelpad=15, fontsize=label_fontsize)

    # Крупные деления (с подписями) – каждые 500000
    major_ticks_x = np.arange(0, max_x + 1, 500000)
    major_ticks_y = np.arange(0, max_y + 1, 500000)
    if major_ticks_x[-1] != max_x:
        major_ticks_x = np.append(major_ticks_x, max_x)
    if major_ticks_y[-1] != max_y:
        major_ticks_y = np.append(major_ticks_y, max_y)

    ax.set_xticks(major_ticks_x)
    ax.set_yticks(major_ticks_y)

    # Форматирование меток
    ax.xaxis.set_major_formatter(FuncFormatter(format_ticks))
    ax.yaxis.set_major_formatter(FuncFormatter(format_ticks))

    # Мелкие деления (без подписей) – каждые 50000
    minor_ticks_x = np.arange(0, max_x + 1, 50000)
    minor_ticks_y = np.arange(0, max_y + 1, 50000)

    ax.set_xticks(minor_ticks_x, minor=True)
    ax.set_yticks(minor_ticks_y, minor=True)

    # Настройка внешнего вида делений с учётом размеров шрифта
    ax.tick_params(axis='both', which='major', length=10, width=1.5, labelsize=tick_fontsize)
    ax.tick_params(axis='both', which='minor', length=5, width=1.0, labelsize=0)

    # Сетка: major – сплошная, minor – пунктир
    ax.grid(True, which='major', linestyle='-', linewidth=0.5, color='gray')
    ax.grid(True, which='minor', linestyle=':', linewidth=0.3, color='lightgray')

def plot_dotplot(ax, hits, alpha=0.5, linewidth=1.2, use_points=False, point_size=2):
    """
    Рисует выравнивания:
      - если use_points=False – линии от (qstart,sstart) до (qend,send)
      - если use_points=True – точки в середине каждого выравнивания
    """
    if use_points:
        x_centers = []
        y_centers = []
        for qstart, qend, sstart, send in hits:
            x_centers.append((qstart + qend) / 2)
            y_centers.append((sstart + send) / 2)
        ax.scatter(x_centers, y_centers, s=point_size, c='black', alpha=alpha, marker='.')
    else:
        for qstart, qend, sstart, send in hits:
            ax.plot([qstart, qend], [sstart, send],
                    color='black', linewidth=linewidth, alpha=alpha)

def main():
    parser = argparse.ArgumentParser(description='Построение карты локального сходства (dot plot) по результатам BLAST -outfmt 7')
    parser.add_argument('blast_file', help='Файл с результатами BLAST (формат -outfmt 7)')
    parser.add_argument('--alpha', type=float, default=0.5, help='Прозрачность элементов (по умолч. 0.5)')
    parser.add_argument('--linewidth', type=float, default=1.2, help='Толщина линий (по умолч. 1.2)')
    parser.add_argument('--points', action='store_true', help='Рисовать точки вместо линий')
    parser.add_argument('--point-size', type=float, default=2.0, help='Размер точек (по умолч. 2.0)')
    parser.add_argument('--output', '-o', default='dotplot.png', help='Имя выходного файла (по умолч. dotplot.png)')
    parser.add_argument('--dpi', type=int, default=300, help='Разрешение сохранения (по умолч. 300)')

    # Параметры масштабирования текста
    parser.add_argument('--title-fontsize', type=int, default=16, help='Размер шрифта заголовка (по умолч. 16)')
    parser.add_argument('--label-fontsize', type=int, default=14, help='Размер шрифта подписей осей (по умолч. 14)')
    parser.add_argument('--tick-fontsize', type=int, default=12, help='Размер шрифта меток делений (по умолч. 12)')

    args = parser.parse_args()

    # Парсинг данных
    query_id, subject_id, hits, max_q, max_s = parse_blast_outfmt7(args.blast_file)

    if query_id is None:
        query_id = 'Query'
    if subject_id is None:
        subject_id = 'Subject'

    # Фигура с соотношением 16:9 (альбомная)
    fig, ax = plt.subplots(figsize=(16, 9))
    setup_axes(ax, max_q, max_s, query_id, subject_id,
               title_fontsize=args.title_fontsize,
               label_fontsize=args.label_fontsize,
               tick_fontsize=args.tick_fontsize)

    plot_dotplot(ax, hits,
                 alpha=args.alpha,
                 linewidth=args.linewidth,
                 use_points=args.points,
                 point_size=args.point_size)

    ax.set_title(f'blastn: {query_id} vs {subject_id}',
                 fontsize=args.title_fontsize)
    plt.tight_layout()
    plt.savefig(args.output, dpi=args.dpi)
    print(f"График сохранён в {args.output}")

if __name__ == '__main__':
    main()
