# MySQL table schema comparison (MYSQL表结构对比)

import re

import pymysql

class SchemaMysql:
    """Compare the table schema of a source MySQL database with a destination.

    The destination is treated as the copy to bring in line with the source:

    * tables only in the source are created in the destination (executed);
    * tables only in the destination are dropped -- executed only when the
      ``isdrop`` flag given to the constructor is true, otherwise reported;
    * for tables present in both, column / index differences are turned into
      ``ALTER TABLE`` statements that ``runExec`` prints (execution there is
      commented out).

    Usage: construct, call ``init`` to open both connections, ``run`` to
    compare, then ``close_db``.
    """

    # A backtick-quoted MySQL identifier; a literal backtick inside the name
    # is written as two backticks.
    _IDENT_RE = re.compile(r'`(?:[^`]|``)+`')

    # Leading keywords of the non-column lines SHOW CREATE TABLE emits inside
    # the table body.  Classifying by prefix (instead of "contains 'KEY'")
    # keeps columns such as `API_KEY`, or a COMMENT mentioning a key, from
    # being mistaken for indexes.
    _INDEX_PREFIXES = (
        'PRIMARY KEY',
        'UNIQUE KEY',
        'KEY ',
        'FULLTEXT KEY',
        'SPATIAL KEY',
        'CONSTRAINT ',
    )

    def __init__(self, src_info, des_info, isdrop):
        """Remember the connection settings of both servers.

        :param src_info: ``[host, user, password, port]`` of the source server.
        :param des_info: ``[host, user, password, port]`` of the destination.
        :param isdrop: if true, tables that exist only in the destination are
            dropped; otherwise their DROP statements are only printed.
        """
        self.src_ip = src_info[0]
        self.src_db_user = src_info[1]
        self.src_db_pass = src_info[2]
        self.src_db_port = src_info[3]
        self.des_ip = des_info[0]
        self.des_db_user = des_info[1]
        self.des_db_pass = des_info[2]
        self.des_db_port = des_info[3]
        self.isDrop = isdrop

    def init(self, src_db, des_db):
        """Open connections and cursors to database ``src_db`` and ``des_db``.

        If connecting to the destination fails, the already-open source
        connection is closed before the error propagates.
        """
        self.src_db = src_db
        self.dbsrc = self.connect_db(self.src_ip, src_db, self.src_db_user, self.src_db_pass, self.src_db_port)
        try:
            self.dbdes = self.connect_db(self.des_ip, des_db, self.des_db_user, self.des_db_pass, self.des_db_port)
        except Exception:
            self.dbsrc.close()
            raise
        self.cursorsrc = self.dbsrc.cursor()
        self.cursordes = self.dbdes.cursor()

    def close_db(self):
        """Close both cursors and both connections opened by ``init``."""
        self.cursorsrc.close()
        self.cursordes.close()
        self.dbsrc.close()
        self.dbdes.close()

    def connect_db(self, ip, db_name, db_user, db_pass, db_port):
        """Open and return a PyMySQL connection; ``db_port`` may be a string."""
        config = {
            'host': ip,
            'port': int(db_port),
            'user': db_user,
            'password': db_pass,
            'db': db_name,
        }
        return pymysql.connect(**config)

    # ------------------------------------------------------------------
    # Parsing helpers for SHOW CREATE TABLE output
    # ------------------------------------------------------------------

    @staticmethod
    def _quote(name):
        """Return ``name`` as a backtick-quoted MySQL identifier."""
        return '`%s`' % str(name).replace('`', '``')

    @staticmethod
    def _normalize(line):
        """Strip indentation and the trailing separator comma of a body line.

        SHOW CREATE TABLE puts a comma after every body line except the last.
        Without this step, a definition that merely changed position (e.g. the
        former last column once a new column follows it) would look modified.
        """
        line = line.strip()
        if line.endswith(','):
            line = line[:-1]
        return line

    @classmethod
    def _column_name(cls, line):
        """Return the (quoted) column name at the start of a column line."""
        line = cls._normalize(line)
        match = cls._IDENT_RE.match(line)
        # Fall back to the first word if the name is not backtick-quoted.
        return match.group(0) if match else line.split(' ')[0]

    @classmethod
    def _is_index(cls, line):
        """Return True if a body line defines an index/key/constraint, not a column."""
        return cls._normalize(line).startswith(cls._INDEX_PREFIXES)

    @classmethod
    def _body_lines(cls, create_sql):
        """Return the normalized column/index lines of a CREATE TABLE statement.

        Skips the ``CREATE TABLE ... (`` header and stops at the line that
        closes the body (the one starting with ``)``), so table options and
        any partition clause printed after it are not compared.
        """
        body = []
        for line in create_sql.split("\n")[1:]:
            if line.startswith(')'):
                break
            body.append(cls._normalize(line))
        return body

    @classmethod
    def _drop_index_clause(cls, line):
        """Build the ``DROP ...`` clause removing the index/constraint in ``line``.

        :raises ValueError: if the line has no backtick-quoted name (and is
            not a PRIMARY KEY).
        """
        line = cls._normalize(line)
        if line.startswith('PRIMARY KEY'):
            return 'DROP PRIMARY KEY'
        match = cls._IDENT_RE.search(line)
        if match is None:
            raise ValueError('cannot find index name in %r' % line)
        name = match.group(0)
        if line.startswith('CONSTRAINT '):
            rest = line[match.end():].lstrip()
            if rest.startswith('FOREIGN KEY'):
                return 'DROP FOREIGN KEY %s' % name
            if rest.startswith('CHECK'):
                return 'DROP CHECK %s' % name
            return 'DROP CONSTRAINT %s' % name
        # KEY / UNIQUE KEY / FULLTEXT KEY / SPATIAL KEY: first identifier is the index name.
        return 'DROP INDEX %s' % name

    # ------------------------------------------------------------------
    # Schema reading / diffing
    # ------------------------------------------------------------------

    def getTable(self, cursor):
        """Return the names of base tables (views excluded) visible to ``cursor``."""
        cursor.execute("show full tables where Table_type = 'BASE TABLE';")
        return [row[0] for row in cursor.fetchall()]

    def getField(self, cursor, tb):
        """Return the SHOW CREATE TABLE row for ``tb``: ``(name, create_sql)``."""
        cursor.execute("show create table " + self._quote(tb))
        return cursor.fetchone()

    def create_table(self, cursor, tb):
        """Create every table named in ``tb`` in the destination.

        The CREATE TABLE statement is read through ``cursor`` (the source),
        printed, and executed on the destination cursor.
        """
        # NOTE(review): a table whose FOREIGN KEY references another table in
        # ``tb`` that has not been created yet will fail; consider
        # SET FOREIGN_KEY_CHECKS=0 on the destination if that occurs.
        for name in tb:
            cursor.execute("show create table " + self._quote(name))
            create_table_sql = cursor.fetchone()[1]
            print(create_table_sql)
            self.cursordes.execute(create_table_sql)

    def field_del_add(self, tb, des_field_table, src_field_table):
        """Build ALTER statements reconciling the column lines of table ``tb``.

        :param tb: table name.
        :param des_field_table: column lines found only in the destination.
        :param src_field_table: column lines found only in the source.
        :returns: DROP COLUMN statements (column only in the destination),
            then ADD COLUMN (only in the source), then MODIFY COLUMN (present
            on both sides with a different definition), each group sorted.
        """
        des_cols = {self._column_name(line): self._normalize(line) for line in des_field_table}
        src_cols = {self._column_name(line): self._normalize(line) for line in src_field_table}
        table = self._quote(tb)
        runlist = []
        # Drop columns that no longer exist in the source.
        for name in sorted(des_cols.keys() - src_cols.keys()):
            runlist.append("ALTER TABLE %s DROP COLUMN %s" % (table, name))
        # Add columns that are new in the source.
        for name in sorted(src_cols.keys() - des_cols.keys()):
            runlist.append("ALTER TABLE %s ADD COLUMN %s" % (table, src_cols[name]))
        # Modify columns present on both sides whose definition differs.
        for name in sorted(src_cols.keys() & des_cols.keys()):
            runlist.append("ALTER TABLE %s MODIFY COLUMN %s" % (table, src_cols[name]))
        return runlist

    def runExec(self, runlist):
        """Print each generated ALTER statement.

        The ``execute`` call is commented out, so nothing is applied to the
        destination; uncomment it to apply the changes.
        """
        for sql in runlist:
            print('TABLE FIELD:  %s' % (sql))
            # self.cursordes.execute(sql)

    def index_del_add(self, tb, des_field_index, src_field_index):
        """Build ALTER statements reconciling the index/constraint lines of ``tb``.

        Examples of generated statements::

            ALTER TABLE `t` DROP INDEX `idx`
            ALTER TABLE `t` DROP FOREIGN KEY `fk`
            ALTER TABLE `t` ADD UNIQUE KEY `idx` (`col`)

        :param tb: table name.
        :param des_field_index: index lines found only in the destination.
        :param src_field_index: index lines found only in the source.
        :returns: all DROP statements followed by all ADD statements. A
            changed index appears on both sides, so it is dropped and re-added.
        """
        table = self._quote(tb)
        drops = [self._normalize(line) for line in des_field_index]
        adds = [self._normalize(line) for line in src_field_index]
        # Drop foreign-key constraints before the indexes that back them; add
        # them after the plain indexes they rely on.
        drops.sort(key=lambda line: (not line.startswith('CONSTRAINT'), line))
        adds.sort(key=lambda line: (line.startswith('CONSTRAINT'), line))
        runlist = ["ALTER TABLE %s %s" % (table, self._drop_index_clause(line)) for line in drops]
        runlist += ["ALTER TABLE %s ADD %s" % (table, line) for line in adds]
        return runlist

    def checkField(self, tb):
        """Print the ALTER statements that make ``tb`` in the destination match the source.

        Body lines (columns, indexes, constraints) present in only one of the
        two CREATE TABLE statements are split into column and index changes.
        The pseudo-table ``dual`` is skipped.
        """
        if tb == 'dual':
            return
        des_field = self.getField(self.cursordes, tb)
        src_field = self.getField(self.cursorsrc, tb)
        if des_field[0] != src_field[0]:
            return
        src_lines = set(self._body_lines(src_field[1]))
        des_lines = set(self._body_lines(des_field[1]))
        des_diff = des_lines - src_lines  # only in destination, or changed
        src_diff = src_lines - des_lines  # only in source, or changed
        # Classify both sides with the same rule so a line is never an index
        # on one side and a column on the other.
        des_field_index = [line for line in des_diff if self._is_index(line)]
        des_field_table = [line for line in des_diff if not self._is_index(line)]
        src_field_index = [line for line in src_diff if self._is_index(line)]
        src_field_table = [line for line in src_diff if not self._is_index(line)]
        table_sql = self.field_del_add(tb, des_field_table, src_field_table)
        index_sql = self.index_del_add(tb, des_field_index, src_field_index)
        self.runExec(table_sql + index_sql)

    def checkTable(self, tbSrc, tbDes, cursrc, curdes):
        """Reconcile the table lists, then compare every shared table.

        Tables only in the source are created. Tables only in the destination
        are dropped when ``isDrop`` is true; otherwise the statement is only
        printed. ``cursrc``/``curdes`` are accepted for compatibility; the
        instance cursors opened by ``init`` are used.
        """
        src_set = set(tbSrc)
        des_set = set(tbDes)
        table_add = sorted(src_set - des_set)
        table_del = sorted(des_set - src_set)
        if table_add:
            print('ADD TABLE:  %s' % (table_add))
            self.create_table(self.cursorsrc, table_add)
        for name in table_del:
            sql = 'drop table %s' % self._quote(name)
            if self.isDrop:
                print(sql)
                self.cursordes.execute(sql)
            else:
                print('SKIP (isdrop disabled): %s' % sql)
        for tb in sorted(src_set & des_set):
            self.checkField(tb)

    def run(self):
        """Compare the base-table lists of both databases and reconcile them."""
        srctable = self.getTable(self.cursorsrc)
        destable = self.getTable(self.cursordes)
        self.checkTable(srctable, destable, self.cursorsrc, self.cursordes)


if __name__ == '__main__':
    # Connection settings: [host, user, password, port].
    # NOTE(review): credentials are hard-coded; consider reading them from
    # the environment or a config file.
    src_info = ['172.17.13.81', 'root', '123456', '3306']
    des_info = ['172.17.13.51', 'root', '123456', '3306']
    # Databases to compare, paired position by position (src_db[i] vs des_db[i]).
    src_db = ['test']
    des_db = ['test']
    for src_name, des_name in zip(src_db, des_db):
        if src_name and des_name:
            print('\n\n-------------------------------------------------------------%s-----------------------------------------------------------\n' % (src_name))
            schema = SchemaMysql(src_info, des_info, True)
            schema.init(src_name, des_name)
            try:
                schema.run()
            finally:
                # Always release both connections, even if the comparison fails.
                schema.close_db()