import sqlite3
import warnings
import subprocess as sp
import os
import sys
import numpy as np
import pandas as pd
from pathlib import Path

from .pdb2sql_base import pdb2sql_base

class pdb2sql(pdb2sql_base):
    """Store the ATOM records of a PDB structure in a SQLite table."""

    def __init__(self, pdbfile, tablename='atom', **kwargs):
        """Create a SQL database with PDB data.

        Notes:
            Only "ATOM" records of the PDB are parsed; every other record,
            e.g. HETATM, is ignored.

        Args:
            pdbfile (str, list, ndarray): pdb file or data
            tablename (str): name of the SQL table to create
            kwargs: forwarded unchanged to the base class constructor

        Examples:
            >>> db = pdb2sql.pdb2sql('3CRO.pdb')
        """
        super().__init__(pdbfile, **kwargs)

        # create the database
        self._create_sql()
        self._create_table(pdbfile, tablename=tablename)

        # fix the chain ID
        if self.fix_chainID:
            self._fix_chainID()

    def __call__(self, **kwargs):
        """Return a pdb2sql instance of the selected parts.

        Args:
            kwargs: argument to select atoms, dict value must be list,
                e.g.:
                    - name = ['CA', 'O']
                    - no_name = ['CA', 'C']
                    - chainID = ['A']
                    - no_chainID = ['A']

        Returns:
            pdb2sql: a pdb2sql instance built from the first table only

        Examples:
            >>> sqldb = pdb2sql('1AK4.pdb')
            >>> dbA = sqldb(chainID='A')
        """
        names = self._get_table_names()
        if len(names) > 1:
            # The table name belongs in the message: the second positional
            # argument of warnings.warn is the warning *category*.
            warnings.warn(
                'pdbsql is meant for single structure. '
                'To use multiple structures use many2sql. '
                'This call will only return the data of '
                f'the first table : {names[0]}')

        pdb_data = self.sql2pdb(tablename=names[0], **kwargs)
        return pdb2sql(pdb_data, tablename=names[0])

    def __repr__(self):
        return f'{self.__module__}.{self.__class__.__name__} object'

    def _create_sql(self, tablename='ATOM'):
        """Open the SQLite connection and cursor used by this instance.

        When ``self.sqlfile`` is None the database lives in RAM only;
        otherwise an existing file of that name is deleted and a fresh
        database file is created.

        Args:
            tablename (str): unused, kept for backward compatibility
        """
        sqlfile = self.sqlfile

        if self.verbose:
            print('-- Create SQLite3 database')

        if sqlfile is None:
            # no db name given: the db is only in RAM
            self.conn = sqlite3.connect(':memory:')
        else:
            # start from a clean file; os.remove avoids spawning a shell
            if os.path.isfile(sqlfile):
                os.remove(sqlfile)
            self.conn = sqlite3.connect(sqlfile)

        self.c = self.conn.cursor()

    def _create_table(self, pdbfile, tablename='ATOM'):
        """Create the atom table and fill it with the ATOM records.

        Sets ``self._nModel`` to the number of ENDMDL records met.

        Args:
            pdbfile (str, list, ndarray): pdb file or data
            tablename (str): table name; special characters are replaced
                with '_'

        Raises:
            ValueError: a pdb line is longer than 80 characters, or a
                missing chainID cannot be recovered from the segID
        """
        # size of the things
        ncol = len(self.col)

        # clean up tablename (raw string: the backslash is a literal char)
        unsafe = r"!@#$%^&*()[]{};:,./<>?\|`~-=_+"
        tablename = tablename.translate(
            str.maketrans(unsafe, '_' * len(unsafe)))

        # column definitions and one placeholder per column
        header = ', '.join(
            f'{colname} {coltype}' for colname, coltype in self.col.items())
        qm = ','.join('?' * ncol)

        # create the table
        self.c.execute(f'CREATE TABLE {tablename} ({header})')

        # get pdb data
        pdbdata = pdb2sql.read_pdb(pdbfile)

        self._nModel = 0
        data_atom = []
        for line in pdbdata:
            line = str(line)

            if line.startswith('ATOM'):
                line = line.split('\n')[0]
            elif line.startswith('ENDMDL'):
                self._nModel += 1
                continue
            else:
                continue

            # check pdb line format
            line = pdb2sql._format_pdb_linelength(line)

            # browse all attributes of each atom
            at = ()
            for colname, coltype in self.col.items():
                # Only columns with a delimiter are read from the line.
                # NOTE(review): assumes self.col is the delimiter columns
                # plus the model column appended below, so each tuple
                # matches the placeholders - both dicts are defined
                # outside this class, confirm there.
                if colname not in self.delimiter:
                    continue

                start, stop = self.delimiter[colname][:2]
                data = line[start:stop].strip()

                # Empty chainID, occ, temp and element are not allowed:
                # reset them to a recovered or fixed value
                if not data:
                    if colname == "chainID":
                        data = pdb2sql._get_chainID(line)
                    if colname == "occ":
                        data = 1.00
                    if colname == "temp":
                        data = 10.00
                    if colname == "element":
                        data = pdb2sql._get_element(line)

                # convert it if necessary
                if coltype == 'INT':
                    data = int(data)
                elif coltype == 'REAL':
                    data = float(data)

                at += (data,)

            # append the model number
            at += (self._nModel,)
            data_atom.append(at)

        # push in the database
        self.c.executemany(
            f'INSERT INTO {tablename} VALUES ({qm})', data_atom)

    @staticmethod
    def read_pdb(pdbfile):
        """Read pdb file or data to a list.

        Args:
            pdbfile (str, bytes, Path, list or ndarray): pdb file or data.
                A str that is not an existing path is treated as pdb
                content when it holds more than three ATOM lines.

        Raises:
            FileNotFoundError: pdb file not found or not a regular file
            ValueError: invalid input (unsupported type, empty sequence)

        Returns:
            list: pdb content in list format

        Examples:
            >>> pdb2sql.read_pdb('3CRO.pdb')
        """
        # RMK we go through the data twice: once to read the lines and
        # once to parse them. We could do better than that, but the most
        # time consuming step seems to be the CREATE TABLE query.
        if isinstance(pdbfile, bytes):
            pdbfile = pdbfile.decode()

        if isinstance(pdbfile, str):
            if os.path.exists(pdbfile):
                if not os.path.isfile(pdbfile):
                    raise FileNotFoundError(f'{pdbfile} is not a file')
                with open(pdbfile, 'r') as fi:
                    pdbdata = fi.readlines()
            elif pdbfile.count('\nATOM ') > 3:
                # input is pdb content
                pdbdata = pdbfile.split('\n')
            else:
                # invalid path
                raise FileNotFoundError(f'File not found: {pdbfile}')

        elif isinstance(pdbfile, Path):
            if not pdbfile.exists():
                raise FileNotFoundError(f'File not found: {pdbfile}')
            if not pdbfile.is_file():
                raise FileNotFoundError(f'{pdbfile} is not a file')
            with pdbfile.open() as fi:
                pdbdata = fi.readlines()

        elif isinstance(pdbfile, (list, np.ndarray)):
            if isinstance(pdbfile, np.ndarray):
                pdbfile = pdbfile.tolist()
            # an empty sequence has no first element to inspect
            if not pdbfile:
                raise ValueError(f'Invalid pdb input: {pdbfile}')
            if isinstance(pdbfile[0], str):
                pdbdata = pdbfile
            elif isinstance(pdbfile[0], bytes):
                pdbdata = [line.decode() for line in pdbfile]
            else:
                raise ValueError(f'Invalid pdb input: {pdbfile}')

        else:
            raise ValueError(f'Invalid pdb input: {pdbfile}')

        return pdbdata

    @staticmethod
    def _format_pdb_linelength(pdb_line):
        """Pad a pdb line with spaces to exactly 80 characters.

        Args:
            pdb_line (str): one pdb line without its line break

        Raises:
            ValueError: the line is longer than 80 characters

        Returns:
            str: the line, right-padded to 80 characters
        """
        if len(pdb_line) > 80:
            raise ValueError(
                f'pdb line is longer than 80:\n{pdb_line}')
        return pdb_line.ljust(80)

    @staticmethod
    def _get_chainID(pdb_line):
        """Use the segID of a pdb line as replacement for a missing chainID.

        Args:
            pdb_line (str): one PDB ATOM line

        Raises:
            ValueError: the segID columns are empty as well

        Returns:
            str: the segID found in the line
        """
        segID_ind = [72, 76]  # segID columns in pdb
        segID = pdb_line[segID_ind[0]:segID_ind[1]].strip()
        if not segID:
            raise ValueError('chainID not found')
        warnings.warn("Missing chainID and set it with segID")
        return segID

    @staticmethod
    def _get_element(pdb_line):
        """Get element type from the atom type of a pdb line.

        Notes:
            Atom type occupies 13-16th columns of a PDB line.
            Four situations are handled:

                col 13  14  15  16
                        C   A        returns the 14th column: 'C'
                    C   A            returns columns 13-14: 'CA'
                    1   H   G        13th is a digit, returns the 14th: 'H'
                    H   E   2   1    13th is 'H' and 16th is set: 'H'

        Args:
            pdb_line (str): one PDB ATOM line

        Returns:
            str: element name
        """
        first_char = pdb_line[12].strip()
        last_char = pdb_line[15].strip()

        if not first_char:
            return pdb_line[13]
        if first_char in "0123456789":
            return pdb_line[13]
        if first_char == "H" and last_char:
            return "H"
        return pdb_line[12:14]

    # replace the chain ID by A,B,C,D, ..... in that order
    def _fix_chainID(self):
        """Rename the chains A, B, C, ... following their sorted order.

        Raises:
            ValueError: more than 26 different chain IDs are present
        """
        from string import ascii_uppercase

        # NOTE(review): when ENDMDL records were read (self._nModel > 0)
        # `get` returns one list per model, which this method does not
        # unpack - confirm whether multi-model input can reach it.
        chainID = self.get('chainID')
        natom = len(chainID)
        chainID = sorted(set(chainID))

        if len(chainID) > len(ascii_uppercase):
            # raise instead of killing the interpreter from library code
            raise ValueError(
                "More than 26 chains have been detected, "
                "not supported so far")

        # declare and fill in the new names
        newID = [''] * natom
        for chain, letter in zip(chainID, ascii_uppercase):
            for ind in self.get('rowID', chainID=chain):
                newID[ind] = letter

        # update the new name
        self.update_column('chainID', newID)

    # get the names of the columns
    def get_colnames(self):
        """Get SQL column names.

        Returns:
            list: 'rowID' followed by the column names of the first table

        Examples:
            >>> db.get_colnames()
        """
        tablename = self._get_table_names()[0]
        cursor = self.conn.execute(f'select * from {tablename}')
        return ['rowID'] + [desc[0] for desc in cursor.description]

    def print_colnames(self):
        """Print out SQL column names.

        Examples:
            >>> db.print_colnames()
        """
        print('Possible column names are:')
        for name in self.get_colnames():
            print('\t' + name)

    def print(self, columns='*', tablename='atom', **kwargs):
        """Print out SQL ATOM table.

        Notes:
            Float number is stored in original precision. It will be
            formatted properly when output pdb with `exportpdb` method.

        Args:
            columns (str): columns to retrieve, e.g. "x,y,z".
                If "*" all the columns are returned.
                Available columns: serial, name, altLoc, resName, chainID,
                resSeq, iCode, x, y, z, occ, temp, element, model
            tablename (str): name of the table to print
            kwargs: argument to select atoms, dict value must be list,
                e.g.:
                    - name = ['CA', 'O']
                    - no_name = ['CA', 'C']
                    - chainID = ['A']
                    - no_chainID = ['A']

        Examples:
            >>> db.print()
        """
        data = self.get(columns, tablename=tablename, **kwargs)

        # without an explicit model selection `get` returns one list per
        # model: put all the rows on the same level
        if self._nModel > 0 and 'model' not in kwargs:
            data = [row for model_rows in data for row in model_rows]

        if columns == '*':
            cursor = self.conn.execute(f'select * from {tablename}')
            names = [desc[0] for desc in cursor.description]
        else:
            names = [col.strip() for col in columns.split(',')]

        arr = np.array(data)
        if arr.size == 0:
            df = pd.DataFrame(columns=names)
        else:
            # a single column comes back flattened from `get`
            if arr.ndim == 1:
                arr = arr.reshape((-1, 1))
            df = pd.DataFrame(arr, columns=names)

        print(df.to_csv(sep='\t', index=False))

    # get the properties
    def get(self, columns, tablename='ATOM', **kwargs):
        """Execute simple SQL query to extract values.

        Each keyword must be a valid column, optionally prefixed with
        'no_' to negate the condition. AND is assumed between different
        keywords, OR between the different values of a given keyword.
        rowID values are 0-based on the python side.

        Args:
            columns (str): columns to retrieve, e.g. "x,y,z".
                If "*" all the columns are returned.
                Available columns: serial, name, altLoc, resName, chainID,
                resSeq, iCode, x, y, z, occ, temp, element, model
            tablename (str): name of the table to query
            kwargs: argument to select atoms, dict value must be list,
                e.g.:
                    - name = ['CA', 'O']
                    - no_name = ['CA', 'C']
                    - chainID = ['A']
                    - no_chainID = ['A']

        Returns:
            list: one entry per selected row; the entry is the bare value
                when a single column is returned, else a list of values.
                When ENDMDL records were read and no `model` keyword is
                given, a list with one such list per model.

        Raises:
            TypeError: `columns` is not a str
            ValueError: invalid column name, too many values for a single
                query, or a negative selection too large to be queried

        Examples:
            >>> db.get('x,y,z', chainID=['A'], no_resName=['ALA', 'TRP'])
        """
        # check arguments format
        valid_colnames = self.get_colnames()

        if not isinstance(columns, str):
            raise TypeError("argument columns must be str")

        if columns != '*':
            for col in columns.split(','):
                if col.strip() not in valid_colnames:
                    raise ValueError(
                        f'Invalid column name {col}. Possible names are\n'
                        f'{valid_colnames}')

        # multi-model case: one query per model
        if 'model' not in kwargs and self._nModel > 0:
            model_data = []
            for iModel in range(self._nModel):
                kwargs['model'] = iModel
                model_data.append(
                    self.get(columns, tablename=tablename, **kwargs))
            return model_data

        if not kwargs:
            # no condition: take the entire table
            query = f'SELECT {columns} FROM {tablename}'
            data = [list(row) for row in self.c.execute(query)]

        else:
            # check that all the keys exist
            for key in kwargs:
                colname = key[3:] if key.startswith('no_') else key
                try:
                    self.c.execute(
                        f"SELECT EXISTS(SELECT {colname} FROM {tablename})")
                except (sqlite3.Error, sqlite3.Warning) as err:
                    raise ValueError(
                        f'Invalid column name {colname}. '
                        f'Possible names are\n{self.get_colnames()}'
                    ) from err

            # form the query and the tuple value
            query = f'SELECT {columns} FROM {tablename} WHERE '
            conditions = []
            vals = ()

            for key, v in kwargs.items():

                # deals with negative conditions
                if key.startswith('no_'):
                    k, neg = key[3:], ' NOT'
                else:
                    k, neg = key, ''

                if isinstance(v, list):
                    nv = len(v)

                    # SQL has a hard limit on the number of values of a
                    # query; larger lists are cut in chunks, one query
                    # per chunk, and the results are concatenated
                    if nv > self.max_sql_values:
                        if neg:
                            # the union of "NOT IN chunk" queries is not
                            # "NOT IN list", so chunking cannot be used
                            raise ValueError(
                                f'Negative selection {key} has {nv} '
                                f'values, more than max_sql_values '
                                f'({self.max_sql_values}); it cannot be '
                                f'split into several queries')
                        chunk_size = self.max_sql_values
                        data = []
                        for start in range(0, nv, chunk_size):
                            new_kwargs = dict(kwargs)
                            new_kwargs[key] = v[start:start + chunk_size]
                            data += self.get(
                                columns, tablename=tablename, **new_kwargs)
                        return data

                    # python rowID is 0-based, SQL rowID is 1-based
                    if k == 'rowID':
                        vals += tuple(int(iv + 1) for iv in v)
                    else:
                        vals += tuple(v)
                else:
                    nv = 1
                    if k == 'rowID':
                        vals += (int(v + 1),)
                    else:
                        vals += (v,)

                # create the condition for that key
                conditions.append(
                    k + neg + ' in (' + ','.join('?' * nv) + ')')

            # stitch the conditions and append to the query
            query += ' AND '.join(conditions)

            # error if vals is too long
            if len(vals) > self.SQLITE_LIMIT_VARIABLE_NUMBER:
                print(
                    '\nError : SQL Queries can only handle a total of 999 values')
                print(' : The current query has %d values' % len(vals))
                print(' : Hence it will fails.')
                print(
                    ' : You are in a rare situation where MULTIPLE conditions have')
                print(
                    ' : have a combined number of values that are too large')
                print(' : These conditions are:')
                ntot = 0
                for k, v in kwargs.items():
                    # scalar values count for one
                    nval = len(v) if isinstance(v, list) else 1
                    print(' : --> %10s : %d values' % (k, nval))
                    ntot += nval
                print(' : --> %10s : %d values' % ('Total', ntot))
                print(
                    ' : Try to decrease max_sql_values in\n')
                raise ValueError('Too many SQL variables')

            # query the sql database and return the answer in a list
            data = [list(row) for row in self.c.execute(query, vals)]

        # empty data
        if not data:
            return data

        # fix the python <--> sql indexes
        colnames = [col.strip() for col in columns.split(',')]
        if 'rowID' in colnames:
            index = colnames.index('rowID')
            for row in data:
                row[index] -= 1

        # flatten the output if each row holds a single value
        if len(data[0]) == 1:
            data = [row[0] for row in data]

        return data

    def update(self, columns, values, tablename='ATOM', **kwargs):
        """Update the database with given values.

        Args:
            columns (str): names of column to update, e.g. "x,y,z".
                Available columns: serial, name, altLoc, resName, chainID,
                resSeq, iCode, x, y, z, occ, temp, element, model
            values (np.ndarray): an array of values that corresponds
                to the number of columns and atoms selected.
            tablename (str): name of the table to update
            kwargs: selection arguments, e.g.: name = ['CA', 'O'],
                chainID = ['A'], or no_name = ['CA', 'C'],
                no_chainID = ['A'].

        Raises:
            TypeError: `columns` is not a str
            ValueError: invalid column name, or the shape of `values`
                does not match the columns and the selected rows

        Examples:
            >>> values = np.array([[1.,2.,3.], [4.,5.,6.]])
            >>> db.update("x,y,z", values=values, resName='MET',
            ...           name=['CA', 'CB'])
        """
        # check arguments format
        valid_colnames = self.get_colnames()

        if not isinstance(columns, str):
            raise TypeError("argument columns must be str")

        # strip the names, like `get` does
        colnames = [col.strip() for col in columns.split(',')]

        if columns != '*':
            for col in colnames:
                if col not in valid_colnames:
                    raise ValueError(
                        f'Invalid column name {col}. Possible names are\n'
                        f'{valid_colnames}')

        # handle the multi model cases
        if 'model' not in kwargs and self._nModel > 0:
            for iModel in range(self._nModel):
                kwargs['model'] = iModel
                self.update(columns, values, tablename=tablename, **kwargs)
            return

        # check the size
        natt = len(colnames)
        nrow = len(values)
        # no row to inspect when `values` is empty
        ncol = len(values[0]) if nrow else natt

        if natt != ncol:
            raise ValueError(
                'Number of cloumns does not match between argument columns and values')

        # get the row ID of the selection
        rowID = self.get('rowID', tablename=tablename, **kwargs)
        if len(rowID) != nrow:
            raise ValueError(
                'Number of data values incompatible with the given conditions')

        # prepare the query
        assignments = ', '.join(f'{col}=?' for col in colnames)
        query = f'UPDATE {tablename} SET {assignments} WHERE rowID=?'

        # one parameter list per row: the new values, then the 1-based
        # SQL rowID of the row they belong to
        data = [list(val) + [rid + 1] for val, rid in zip(values, rowID)]

        self.c.executemany(query, data)

    def update_column(self, colname, values, index=None, tablename='ATOM'):
        """Update a single column.

        Args:
            colname (str): name of the column to update.
                Available columns: serial, name, altLoc, resName, chainID,
                resSeq, iCode, x, y, z, occ, temp, element, model
            values (list): new values of the column
            index (None, optional): 0-based index of the rows to update
                (default: the first len(values) rows, in order)
            tablename (str): name of the table to update

        Example:
            >>> db.update_column('x',np.random.rand(10),index=list(range(10)))
        """
        if index is None:
            index = range(len(values))

        # SQL rowID is 1-based
        data = [[val, ind + 1] for val, ind in zip(values, index)]

        query = f'UPDATE {tablename} SET {colname}=? WHERE rowID=?'
        self.c.executemany(query, data)

    def add_column(self, colname, coltype='FLOAT', value=0, tablename='ATOM'):
        """Add a new column to the ATOM table with same value for each row.

        Args:
            colname (str): the new column name
            coltype (str): data type of the new column. Default to float.
            value (int/float/str): value for the new column. Default to 0.
                A str is stored as a text value; None gives NULL.
            tablename (str): name of the table to alter

        Examples:
            >>> db.add_column('id', coltype='str', value='positive')
        """
        # DDL statements cannot take bound parameters, so the default is
        # written as a SQL literal; quotes are doubled to escape them
        if value is None:
            default = 'NULL'
        elif isinstance(value, str):
            default = "'" + value.replace("'", "''") + "'"
        else:
            default = str(value)

        quoted_name = '"' + colname.replace('"', '""') + '"'
        query = (f'ALTER TABLE {tablename} ADD COLUMN {quoted_name} '
                 f'{coltype} DEFAULT {default}')
        self.c.execute(query)

    def _commit(self):
        """Commit to the database."""
        self.conn.commit()