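"""
find_common_mnemonics.py

Usage:
    python find_common_mnemonics.py ecoli.txt bacsu.txt

Reads two files of UniProt entry names. Each file is either a TSV export
with a header row (the "Entry Name" column is used) or a plain list with
one entry name per line. For each entry it extracts the functional
mnemonic, i.e. the part before the underscore (RECA in RECA_ECOLI).
It then prints the mnemonics present in both files together with the
matching entry-name pairs. Mnemonics starting with "Y" and the mnemonic
"ENO" are excluded. The report is also saved to common_pairs.txt.
"""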

import sys

def get_mnemonic(entry_name):
    """Return the part of *entry_name* before its first underscore.

    Surrounding whitespace is ignored. If the name contains no underscore
    at all, ``None`` is returned.

    >>> get_mnemonic("RECA_ECOLI")
    'RECA'
    >>> get_mnemonic("P0A7G6") is None
    True
    """
    prefix, separator, _rest = entry_name.strip().partition('_')
    return prefix if separator else None

def load_entry_names(filename):
    """Read UniProt entry names from *filename*.

    Two input layouts are accepted:

    * A TSV export whose first line is a header. The file is treated as TSV
      when that line contains ``Entry Name`` or ``Entry`` followed by a tab.
      Values come from the ``Entry Name`` column, or from the second column
      if the header has no ``Entry Name`` column.
    * A plain list with one entry name per line and no header.

    In both layouts surrounding whitespace is stripped from each name and
    blank names are skipped.

    Args:
        filename: Path of a UTF-8 text file (a leading BOM is tolerated).

    Returns:
        list[str]: Entry names in file order.
    """
    ids = []
    # utf-8-sig drops a leading byte-order mark. With plain utf-8 the BOM
    # sticks to the first header cell or first entry name, which breaks the
    # 'Entry Name' column lookup or the later mnemonic match.
    with open(filename, encoding='utf-8-sig') as f:
        first_line = f.readline()
        # Detect if TSV with header
        if 'Entry Name' in first_line or 'Entry\t' in first_line:
            cols = [c.strip() for c in first_line.rstrip('\r\n').split('\t')]
            try:
                name_col = cols.index('Entry Name')
            except ValueError:
                name_col = 1  # fallback: second column
            for line in f:
                # Remove only the line terminator. strip() would also eat
                # leading tabs (empty leading fields) and shift the columns.
                parts = line.rstrip('\r\n').split('\t')
                if len(parts) > name_col:
                    name = parts[name_col].strip()
                    if name:
                        ids.append(name)
        else:
            # Plain list format: the first line is already data, so it goes
            # through the same blank-line filter as the remaining lines.
            for line in [first_line, *f]:
                name = line.strip()
                if name:
                    ids.append(name)
    return ids

def _index_by_mnemonic(entry_names):
    """Map each functional mnemonic to the entry name it was taken from.

    Names for which ``get_mnemonic`` returns ``None`` or an empty string
    are skipped. If several names share a mnemonic, the last one in
    *entry_names* wins.
    """
    index = {}
    for uid in entry_names:
        m = get_mnemonic(uid)
        if m:
            index[m] = uid
    return index


def _is_excluded(mnemonic):
    """Return True for mnemonics left out of the pair listing (Y* and ENO)."""
    return mnemonic.startswith('Y') or mnemonic == 'ENO'


def main():
    """Compare two entry-name files and report the mnemonics they share.

    Expects exactly two command-line arguments: the ECOLI file and the
    BACSU file (either layout accepted by ``load_entry_names``). Prints a
    summary table and writes the same text to ``common_pairs.txt`` in the
    current working directory.
    """
    if len(sys.argv) != 3:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python find_common_mnemonics.py ecoli.txt bacsu.txt")

    ecoli_ids = load_entry_names(sys.argv[1])
    bacsu_ids = load_entry_names(sys.argv[2])

    print(f"Loaded {len(ecoli_ids)} ECOLI entries, e.g.: {ecoli_ids[:3]}")
    print(f"Loaded {len(bacsu_ids)} BACSU entries, e.g.: {bacsu_ids[:3]}")

    ecoli_dict = _index_by_mnemonic(ecoli_ids)
    bacsu_dict = _index_by_mnemonic(bacsu_ids)

    # The first total counts every shared mnemonic; the table and the
    # final total apply the Y*/ENO exclusions.
    common = sorted(ecoli_dict.keys() & bacsu_dict.keys())
    kept = [m for m in common if not _is_excluded(m)]

    lines = [
        f"Total common mnemonics: {len(common)}",
        "\nPairs (excluding Y* and ENO):",
        f"{'Mnemonic':<12} {'ECOLI ID':<20} {'BACSU ID':<20}",
        "-" * 55,
    ]
    lines.extend(
        f"{m:<12} {ecoli_dict[m]:<20} {bacsu_dict[m]:<20}" for m in kept
    )
    lines.append(f"\nTotal (after exclusions): {len(kept)}")

    output = "\n".join(lines)
    print(output)

    with open("common_pairs.txt", "w", encoding="utf-8") as f:
        # Terminate the last line so the report is a well-formed text file.
        f.write(output + "\n")
    print("\nResults saved to common_pairs.txt")

# Call main() only when the file is executed as a script, so importing the
# module just defines the functions without running the comparison.
if __name__ == "__main__":
    main()
