In [1]:
import json
import pickle
import sys
from pathlib import Path
from collections import Counter
from rdkit import Chem
from IPython import display

# hack to import parent packages
sys.path.append(str(Path.cwd().parent))

from data_utils.evaluator import get_amidation_rxns, get_p_at_k, count_implicit_valence_N
from utils.mol_utils import remove_isotope

In [2]:
# Load the cached per-reaction prediction results and keep only amidation reactions.
# The file handle is closed deterministically via `with` (the original left it open).
# NOTE: pickle.load can execute arbitrary code — only load files you produced/trust.
with open('../data/sub_result_dict.pkl', 'rb') as results_file:
    sub_result_dict = pickle.load(results_file)
rxns = get_amidation_rxns(sub_result_dict.keys())
len(rxns)

100%|██████████| 82932/82932 [01:25<00:00, 967.80it/s] 


1154

In [3]:
# For every amidation reaction keep just the first field of each prediction tuple,
# in the original ranking order, in the shape get_p_at_k consumes.
new_sub_result_dict = {
    rxn: [prediction[0] for prediction in sub_result_dict[rxn]]
    for rxn in rxns
}

In [4]:
# Top-k accuracy over all amidation reactions. n_best is passed explicitly so this
# report uses the same k as the blocked/unblocked breakdown further down.
get_p_at_k(new_sub_result_dict, n_best=10)

top 1 accuracy:  60.49 %
top 2 accuracy:  71.32 %
top 3 accuracy:  76.17 %
top 4 accuracy:  79.12 %
top 5 accuracy:  80.16 %
top 6 accuracy:  80.76 %
top 7 accuracy:  81.37 %
top 8 accuracy:  81.80 %
top 9 accuracy:  82.41 %
top 10 accuracy:  82.58 %


In [5]:
# Substructure-level precision.
# For each reaction, every (substructure, flag) pair is counted at the largest
# multiplicity it reaches within any single prediction's ranking list. The reported
# value is the share of those counted pairs whose flag (ele[1]) is truthy.
total_sub, correct_sub = 0, 0
for rxn in rxns:
    sub2count = Counter()
    for _, _, _, sub_exists_rankings in sub_result_dict[rxn]:
        # Counter union (|) keeps the per-key maximum, replacing the manual
        # "update only if larger" bookkeeping.
        sub2count |= Counter((ele[0], ele[1]) for ele in sub_exists_rankings)
    total_sub += sum(sub2count.values())
    correct_sub += sum(count for (_, is_correct), count in sub2count.items() if is_correct)

# Guard against an empty reaction set instead of raising ZeroDivisionError.
print(correct_sub / total_sub if total_sub else float('nan'))

0.9061313387496597


In [6]:
# Split the amidation reactions into "blocked" and "unblocked" groups.
# A reaction is blocked when at least one correct substructure (an entry whose
# flag ele[1] is truthy in any prediction's ranking list) maps onto reactant atoms
# for which count_implicit_valence_N reports v1_N > 0.
# (v1_N presumably counts nitrogens with a particular implicit valence — TODO
# confirm against data_utils.evaluator.)
blocked_rxns, unblocked_rxns = [], []
for rxn in rxns:
    _product, reactants = rxn  # reaction keys are unpacked as (product, reactants)
    reactants_mol = Chem.MolFromSmiles(reactants)
    if reactants_mol is None:
        raise ValueError(f'cannot parse reactants SMILES: {reactants!r}')

    # Union of the correct substructures over all predictions for this reaction.
    correct_subs = {ele[0] for pred in sub_result_dict[rxn] for ele in pred[-1] if ele[1]}

    is_blocked = False
    for sub in correct_subs:
        parsed_sub = Chem.MolFromSmiles(sub)
        if parsed_sub is None:
            raise ValueError(f'cannot parse substructure SMILES: {sub!r}')
        sub_mol = remove_isotope(parsed_sub)
        # NOTE: GetSubstructMatch returns only the first match, so other
        # occurrences of the substructure in the reactants are not inspected.
        _, _, v1_N, _ = count_implicit_valence_N(
            reactants_mol, idset=reactants_mol.GetSubstructMatch(sub_mol))
        if v1_N > 0:
            is_blocked = True
            break

    if is_blocked:
        blocked_rxns.append(rxn)
    else:
        unblocked_rxns.append(rxn)

# Group sizes, kept under their original names for the cells below.
blocked_inactive_N_in_sub = len(blocked_rxns)
not_blocked_inactive_N_in_sub = len(unblocked_rxns)


In [7]:
# Count of blocked amidation reactions and the fraction of all amidation reactions it represents.
n_blocked = blocked_inactive_N_in_sub
(n_blocked, n_blocked / len(rxns))

(665, 0.5762564991334489)

In [8]:
# Top-k accuracy reported separately for blocked and unblocked reactions, so the
# effect of the blocking criterion on accuracy can be compared side by side.
# As for the all-reactions report, only the first field of each prediction tuple
# is handed to get_p_at_k.
block_sub_result_dict = {rxn: [p[0] for p in sub_result_dict[rxn]] for rxn in blocked_rxns}
unblock_sub_result_dict = {rxn: [p[0] for p in sub_result_dict[rxn]] for rxn in unblocked_rxns}

# Label each report; the original printed two unlabeled blocks of numbers.
for label, subset in (('blocked', block_sub_result_dict), ('unblocked', unblock_sub_result_dict)):
    print(f'--- {label} reactions (n={len(subset)}) ---')
    get_p_at_k(subset, n_best=10)
    print()

top 1 accuracy:  67.67 %
top 2 accuracy:  78.35 %
top 3 accuracy:  81.95 %
top 4 accuracy:  84.66 %
top 5 accuracy:  85.56 %
top 6 accuracy:  86.32 %
top 7 accuracy:  86.92 %
top 8 accuracy:  87.07 %
top 9 accuracy:  87.37 %
top 10 accuracy:  87.52 %


top 1 accuracy:  50.72 %
top 2 accuracy:  61.76 %
top 3 accuracy:  68.30 %
top 4 accuracy:  71.57 %
top 5 accuracy:  72.80 %
top 6 accuracy:  73.21 %
top 7 accuracy:  73.82 %
top 8 accuracy:  74.64 %
top 9 accuracy:  75.66 %
top 10 accuracy:  75.87 %
