Source code for skchem.data.converters.muller_ames
#! /usr/bin/env python
#
# Copyright (C) 2016 Rich Lewis <rl403@cam.ac.uk>
# License: 3-clause BSD
import os
import zipfile
import logging
LOGGER = logging.getLogger(__name__)
import pandas as pd
import numpy as np
import skchem
from .base import Converter
from ... import standardizers
PATCHES = {
'820-75-7': r'NNC(=O)CNC(=O)C=[N+]=[N-]',
'2435-76-9': r'[N-]=[N+]=C1C=NC(=O)NC1=O',
'817-99-2': r'NC(=O)CNC(=O)\C=[N+]=[N-]',
'116539-70-9': r'CCCCN(CC(O)C1=C\C(=[N+]=[N-])\C(=O)C=C1)N=O',
'115-02-6': r'NC(COC(=O)\C=[N+]=[N-])C(=O)O',
'122341-55-3': r'NC(COC(=O)\C=[N+]=[N-])C(=O)O'
}
[docs]class MullerAmesConverter(Converter):
def __init__(self, directory, output_directory, output_filename='muller_ames.h5'):
"""
Args:
directory (str):
Directory in which input files reside.
output_directory (str):
Directory in which to save the converted dataset.
output_filename (str):
Name of the saved dataset. Defaults to `muller_ames.h5`.
Returns:
tuple of str:
Single-element tuple containing the path to the converted dataset.
"""
zip_path = os.path.join(directory, 'ci900161g_si_001.zip')
output_path = os.path.join(output_directory, output_filename)
with zipfile.ZipFile(zip_path) as f:
f.extractall()
# create dataframe
data = pd.read_csv(os.path.join(directory, 'smiles_cas_N6512.smi'),
delimiter='\t', index_col=1,
converters={1: lambda s: s.strip()},
header=None, names=['structure', 'id', 'is_mutagen'])
data = self.patch_data(data, PATCHES)
data['structure'] = data.structure.apply(skchem.Mol.from_smiles)
data = self.standardize(data)
data = self.optimize(data)
keep = self.filter(data)
ms, ys = keep.structure, keep.is_mutagen
indices = data.reset_index().index.difference(keep.reset_index().index)
train = self.parse_splits(os.path.join('splits_train_N6512.csv'))
train = self.drop_indices(train, indices)
splits = self.create_split_dict(train, 'train')
test = self.parse_splits(os.path.join(directory, 'splits_test_N6512.csv'))
test = self.drop_indices(test, indices)
splits.update(self.create_split_dict(test, 'test'))
self.run(ms, ys, output_path, splits=splits)
[docs] def patch_data(self, data, patches):
""" Patch smiles in a DataFrame with rewritten ones that specify diazo
groups in rdkit friendly way. """
LOGGER.info('Patching data...')
for cas, smiles in patches.items():
data.loc[cas, 'structure'] = smiles
return data
[docs] def parse_splits(self, f_path):
LOGGER.info('Parsing splits...')
with open(f_path) as f:
splits = [split for split in f.read().strip().splitlines()]
splits = [[n for n in split.strip().split(',')] for split in splits]
splits = [sorted(int(n) for n in split) for split in splits] # sorted ints
return [np.array(split) - 1 for split in splits] # zero based indexing
[docs] def drop_indices(self, splits, indices):
LOGGER.info('Dropping failed compounds from split indices...')
for i, split in enumerate(splits):
split = split - sum(split > ix for ix in indices)
splits[i] = np.delete(split, indices)
return splits
[docs] def create_split_dict(self, splits, name):
return {'{}_{}'.format(name, i + 1): split \
for i, split in enumerate(splits)}
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
LOGGER.info('Converting Muller Ames Dataset...')
MullerAmesConverter.convert()