Repository: giantbranch/mipsAudit
Branch: master
Commit: 8650a97bb410
Files: 5
Total size: 116.1 KB

Directory structure:
gitextract_w6dphd8d/
├── .github/
│   └── FUNDING.yml
├── README.md
├── mipsAudit.py
└── oldversion(py2)/
    ├── mipsAudit.py
    └── prettytable.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: #
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['http://pic.giantbranch.cn/pic/1551450728861.jpg']

================================================
FILE: README.md
================================================
# IDAPython mipsAudit

## Introduction

This is a simple IDAPython script; more specifically, a helper script for static auditing of MIPS assembly.

It may contain bugs, and improvements are welcome.

> **v3.0 update**: supports the IDA 7.x and 8.x APIs and Python 3, and adds several advanced vulnerability checks.

## Features

### Basic features

1. Locate calls to dangerous functions and highlight those lines (breakpoints can also be set; see the source code for how)
2. Add comments at the instructions where arguments are assigned
3.
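Finally, print a table listing the function name, call address, arguments, and the buffer size of the current function

**Double-click an address in the addr column to jump to that location.**

![17cc62c98820974f8c759dc086dd5acb](17cc62c98820974f8c759dc086dd5acb.png)

![28069d48cf3f357dd83e42406e10d980](28069d48cf3f357dd83e42406e10d980.png)

### New in v3.0

#### Advanced vulnerability detection

| Check | Description |
|---------|------|
| **Command injection** | Traces whether the arguments of system/popen/execve come from external input |
| **Stack overflow** | Compares the destination buffer size with the source data length |
| **Format string vulnerabilities** | Detects the %n write primitive and user-controlled format strings |
| **Integer overflow** | Checks where the size argument of malloc/calloc comes from and whether arithmetic is applied to it |
| **Double free** | Tracks free() calls and detects the same pointer being freed more than once |
| **Data flow analysis** | Tracks data returned by read/recv as it flows into dangerous functions |
| **Wrapper function identification** | Identifies custom functions that wrap dangerous functions |

#### Risk-level highlighting

| Color | Level | Description |
|------|------|------|
| Red | HIGH | Needs immediate attention |
| Orange | MEDIUM | Needs manual review |
| Green | LOW | Informational |

#### Result export

A timestamped HTML report is generated automatically and saved in the directory that contains the IDB file:

```
mipsAudit_results_20260129_143052.html
```

#### Configuration file support

The function lists can be extended with custom entries through `mipsAudit_config.json`:

```json
{
    "dangerous_functions": ["custom_strcpy"],
    "command_execution_function": ["custom_exec"],
    "_comment": "Extends the default function lists"
}
```

## Audited dangerous functions

```python
dangerous_functions = [
    "strcpy", "strcat", "sprintf", "read", "getenv",
    "gets", "scanf", "vscanf", "realpath", "access", "stat", "lstat"
]

attention_function = [
    "memcpy", "strncpy", "sscanf", "strncat", "snprintf",
    "vprintf", "printf", "fprintf", "vfprintf", "vsprintf",
    "vsnprintf", "syslog", "memmove", "bcopy"
]

command_execution_function = [
    "system", "execve", "popen", "unlink",
    "execl", "execle", "execlp", "execv", "execvp",
    "dlopen", "mmap", "mprotect"
]

memory_alloc_functions = [
    "malloc", "calloc", "realloc", "memalign",
    "valloc", "pvalloc", "aligned_alloc", "mmap"
]

memory_free_functions = [
    "free", "cfree", "munmap"
]
```

## Execution flow

```
PHASE 1: Basic Function Audit              # basic dangerous-function audit
PHASE 2: Enhanced Vulnerability Detection  # enhanced vulnerability detection
PHASE 3: Advanced Analysis                 # advanced analysis (data flow, wrapper identification)
PHASE 4: Results Summary & Export          # result summary and export
```

## Usage

### Requirements

- IDA Pro 7.0+
- Python 3.x (bundled with IDA)

### How to run

File - Script file

![1561006651468](./1561006651468.png)

Select mipsAudit.py

![1561006737134](./1561006737134.png)

The results then appear: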
![mipsAudit](./mipsAudit.png)

Double-click an address to jump to the corresponding code

![1561006887117](./1561006887117.png)

## Changelog

### v3.0 (2026-01)

- Support for the IDA 7.x+ API
- Migrated to Python 3
- Added format string vulnerability detection (%n detection)
- Added source tracing for command-injection arguments
- Added buffer-size comparison for stack overflows
- Added integer overflow detection (malloc/calloc size argument)
- Added double-free detection
- Added data flow analysis (tracking of read/recv return values)
- Added wrapper function identification
- Added basic-block analysis (fixes false positives across basic blocks)
- Added HTML report export (timestamped)
- Added support for an external JSON configuration file
- Added scan progress display
- Extended the dangerous function lists

### v1.0 (2018-05)

- Initial version by giantbranch

================================================
FILE: mipsAudit.py
================================================
# -*- coding: utf-8 -*-

# reference
# 《ida pro 权威指南》 (The IDA Pro Book)
# 《python 灰帽子》 (Gray Hat Python)
# 《家用路由器0day漏洞挖掘》 (a book on 0-day vulnerability discovery in home routers)
# https://github.com/wangzery/SearchOverflow/blob/master/SearchOverflow.py

# Updated for IDA 7.x+ API and Python 3
# Enhanced with advanced vulnerability detection v3.0

import idc
import idaapi
import idautils
from prettytable import PrettyTable
from collections import defaultdict
import json
import csv
import os
import re
from datetime import datetime

# IDA 7.x+ uses BADADDR from idaapi
BADADDR = idaapi.BADADDR

DEBUG = True

# ============================================================
# Configuration - Can be overridden by external config file
# ============================================================

CONFIG_FILE = "mipsAudit_config.json"

# Default function lists (can be extended via config file)
dangerous_functions = [
    "strcpy",
    "strcat",
    "sprintf",
    "read",
    "getenv",
    "gets",      # No boundary check - extremely dangerous
    "scanf",     # Format input vulnerability
    "vscanf",
    "realpath",  # Path traversal
    "access",    # TOCTOU race condition
    "stat",      # TOCTOU race condition
    "lstat",
]

attention_function = [
    "memcpy",
    "strncpy",
    "sscanf",
    "strncat",
    "snprintf",
    "vprintf",
    "printf",
    "fprintf",
    "vfprintf",
    "vsprintf",
    "vsnprintf",
    "syslog",
    "memmove",
    "bcopy",
]

command_execution_function = [
    "system",
    "execve",
    "popen",
    "unlink",
    "execl",
    "execle",
    "execlp",
    "execv",
    "execvp",
    "dlopen",
    "mmap",      # Memory mapping
    "mprotect",  # Change memory protection
]

# External input source
# functions (for taint tracking)
external_input_functions = [
    "getenv", "read", "recv", "recvfrom", "recvmsg",
    "fgets", "fread", "fgetc", "gets", "scanf", "fscanf",
    "getchar", "getc", "fgetws", "getwchar",
    "getline", "getdelim",
    "socket", "accept", "gethostbyname",
]

# Memory management functions
memory_alloc_functions = [
    "malloc", "calloc", "realloc", "memalign",
    "valloc", "pvalloc", "aligned_alloc", "mmap",
]

memory_free_functions = [
    "free", "cfree", "munmap",
]

# Format string functions (for %n detection)
format_string_functions = [
    "printf", "fprintf", "sprintf", "snprintf",
    "vprintf", "vfprintf", "vsprintf", "vsnprintf",
    "syslog",
]

# describe arg num of function
one_arg_function = [
    "getenv", "system", "unlink", "free", "cfree", "malloc", "gets",
]

# realloc(ptr, size) takes two arguments
two_arg_function = [
    "strcpy", "strcat", "popen", "calloc", "dlopen",
    "access", "stat", "lstat", "realpath", "realloc",
]

# fgets(s, size, stream) and mprotect(addr, len, prot) take three arguments
three_arg_function = [
    "strncpy", "strncat", "memcpy", "memmove", "bcopy",
    "execve", "read", "recv", "fread", "fgets", "mprotect",
]

format_function_offset_dict = {
    "sprintf": 1,
    "sscanf": 1,
    "snprintf": 2,
    "vprintf": 0,
    "printf": 0,
    "fprintf": 1,
    "vfprintf": 1,
    "vsprintf": 1,
    "vsnprintf": 2,
    "syslog": 1,
    "scanf": 0,
}

# Risk level colors
RISK_HIGH = 0x0000ff    # Red
RISK_MEDIUM = 0x00a5ff  # Orange
RISK_LOW = 0x00ff00     # Green
RISK_INFO = 0xffff00    # Cyan

# ============================================================
# Enhanced Analysis - Data Structures
# ============================================================

# Store function call information for cross-reference analysis
taint_sources = {}                   # addr -> function_name (external input sources)
free_calls = defaultdict(list)       # func_addr -> [(call_addr, arg_info), ...]
audit_results = []                   # Store all findings for export
wrapper_functions = {}               # Detected wrapper functions
data_flow_graph = defaultdict(list)  # addr -> [(dest_addr, dest_func), ...]
# Progress tracking
total_functions = 0
processed_functions = 0

# ============================================================
# Configuration File Support
# ============================================================

def get_output_dir():
    """Get output directory - uses IDB directory in IDA environment"""
    try:
        idb_path = idc.get_idb_path()
        if idb_path:
            return os.path.dirname(idb_path)
    except:
        pass
    return os.getcwd()

def load_config():
    """Load configuration from external JSON file"""
    global dangerous_functions, attention_function, command_execution_function
    global external_input_functions, memory_alloc_functions, memory_free_functions
    global format_string_functions

    config_path = os.path.join(get_output_dir(), CONFIG_FILE)
    if os.path.exists(config_path):
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Extend function lists from config
            if 'dangerous_functions' in config:
                dangerous_functions.extend(config['dangerous_functions'])
            if 'attention_function' in config:
                attention_function.extend(config['attention_function'])
            if 'command_execution_function' in config:
                command_execution_function.extend(config['command_execution_function'])
            if 'external_input_functions' in config:
                external_input_functions.extend(config['external_input_functions'])
            if 'memory_alloc_functions' in config:
                memory_alloc_functions.extend(config['memory_alloc_functions'])
            if 'memory_free_functions' in config:
                memory_free_functions.extend(config['memory_free_functions'])
            if 'format_string_functions' in config:
                format_string_functions.extend(config['format_string_functions'])

            print("[*] Loaded configuration from %s" % config_path)
            return True
        except Exception as e:
            print("[!] 
Error loading config: %s" % str(e))
    return False

def save_default_config():
    """Save current configuration as default config file"""
    config_path = os.path.join(get_output_dir(), CONFIG_FILE)
    config = {
        "dangerous_functions": [],
        "attention_function": [],
        "command_execution_function": [],
        "external_input_functions": [],
        "memory_alloc_functions": [],
        "memory_free_functions": [],
        "format_string_functions": [],
        "_comment": "Add custom functions to extend the default lists"
    }
    try:
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)
        print("[*] Created default config file: %s" % config_path)
    except Exception as e:
        print("[!] Error saving config: %s" % str(e))

# ============================================================
# Progress Display
# ============================================================

def show_progress(current, total, prefix="Progress"):
    """Display progress bar in IDA output"""
    if total == 0:
        return
    percent = (current * 100) // total
    bar_len = 30
    filled = (current * bar_len) // total
    bar = '=' * filled + '-' * (bar_len - filled)
    print("\r%s: [%s] %d%% (%d/%d)" % (prefix, bar, percent, current, total), end='')
    if current == total:
        print()  # New line when complete

# ============================================================
# Helper Functions
# ============================================================

def printFunc(func_name):
    string1 = "========================================"
    string2 = "========== Auditing " + func_name + " "
    strlen = len(string1) - len(string2)
    return string1 + "\n" + string2 + '=' * strlen + "\n" + string1

def getFuncAddr(func_name):
    func_addr = idc.get_name_ea_simple(func_name)
    if func_addr != BADADDR:
        print(printFunc(func_name))
        return func_addr
    return False

def getFormatString(addr):
    op_num = 1
    # idc.get_operand_type Return value
    # o_void 0 // No Operand
    # o_reg 1 // General Register (al, ax, es, ds...)
    #     reg
    # o_mem 2 // Direct Memory Reference (DATA) addr
    # o_phrase 3 // Memory Ref [Base Reg + Index Reg] phrase
    # o_displ 4 // Memory Reg [Base Reg + Index Reg + Displacement] phrase+addr
    # o_imm 5 // Immediate Value value
    # o_far 6 // Immediate Far Address (CODE) addr
    # o_near 7 // Immediate Near Address (CODE) addr
    # o_idpspec0 8 // IDP specific type
    # o_idpspec1 9 // IDP specific type
    # o_idpspec2 10 // IDP specific type
    # o_idpspec3 11 // IDP specific type
    # o_idpspec4 12 // IDP specific type
    # o_idpspec5 13 // IDP specific type
    if idc.get_operand_type(addr, op_num) != 5:
        op_num = op_num + 1
    if idc.get_operand_type(addr, op_num) != 5:
        return "get fail"
    op_string = idc.print_operand(addr, op_num).split(" ")[0].split("+")[0].split("-")[0].replace("(", "")
    string_addr = idc.get_name_ea_simple(op_string)
    if string_addr == BADADDR:
        return "get fail"
    string_content = idc.get_strlit_contents(string_addr)
    if string_content is None:
        return "get fail"
    if isinstance(string_content, bytes):
        string_content = string_content.decode('utf-8', errors='replace')
    return [string_addr, string_content]

def getArgAddr(start_addr, regNum):
    mipscondition = ["bn", "be", "bg", "bl"]
    scan_deep = 50
    count = 0
    reg = "$a" + str(regNum)
    # try to get in the next (code references from this address)
    for next_addr in idautils.CodeRefsFrom(start_addr, 0):
        if next_addr != BADADDR and reg == idc.print_operand(next_addr, 0):
            return next_addr
    # try to get before (code references to this address)
    before_addr = start_addr
    for ref_addr in idautils.CodeRefsTo(start_addr, 0):
        before_addr = ref_addr
        break
    if before_addr == start_addr:
        before_addr = idc.prev_head(start_addr)
    while before_addr != BADADDR:
        if reg == idc.print_operand(before_addr, 0):
            Mnemonics = idc.print_insn_mnem(before_addr)
            if Mnemonics[0:2] in mipscondition:
                pass
            elif Mnemonics[0:1] == "j":
                pass
            else:
                return before_addr
        count = count + 1
        if count > scan_deep:
            break
        before_addr = idc.prev_head(before_addr)
    return BADADDR

def 
getArg(start_addr, regNum):
    mipsmov = ["move", "lw", "li", "lb", "lui", "lhu", "lbu", "la"]
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr != BADADDR:
        Mnemonics = idc.print_insn_mnem(arg_addr)
        if Mnemonics[0:3] == "add":
            if idc.print_operand(arg_addr, 2) == "":
                arg = idc.print_operand(arg_addr, 0) + "+" + idc.print_operand(arg_addr, 1)
            else:
                arg = idc.print_operand(arg_addr, 1) + "+" + idc.print_operand(arg_addr, 2)
        elif Mnemonics[0:3] == "sub":
            if idc.print_operand(arg_addr, 2) == "":
                arg = idc.print_operand(arg_addr, 0) + "-" + idc.print_operand(arg_addr, 1)
            else:
                arg = idc.print_operand(arg_addr, 1) + "-" + idc.print_operand(arg_addr, 2)
        elif Mnemonics in mipsmov:
            arg = idc.print_operand(arg_addr, 1)
        else:
            arg = idc.GetDisasm(arg_addr).split("#")[0]
        idc.set_cmt(arg_addr, "addr: 0x%x " % start_addr + "-------> arg" + str((int(regNum)+1)) + " : " + arg, 0)
        return arg
    else:
        return "get fail"

def audit(func_name):
    func_addr = getFuncAddr(func_name)
    if func_addr == False:
        return False

    # get arg num and set table
    if func_name in one_arg_function:
        arg_num = 1
    elif func_name in two_arg_function:
        arg_num = 2
    elif func_name in three_arg_function:
        arg_num = 3
    elif func_name in format_function_offset_dict:
        arg_num = format_function_offset_dict[func_name] + 1
    else:
        print("The %s function is not listed in any argument-count array; please add it to one, e.g. `two_arg_function`" % func_name)
        return

    # mipscall = ["jal", "jalr", "bal", "jr"]
    table_head = ["func_name", "addr"]
    for num in range(0, arg_num):
        table_head.append("arg"+str(num+1))
    if func_name in format_function_offset_dict:
        table_head.append("format&value[string_addr, num of '%', fmt_arg...]")
    table_head.append("local_buf_size")
    table = PrettyTable(table_head)

    # get references to function (xrefs)
    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        # set color - green (red=0x0000ff, blue=0xff0000)
        idc.set_color(call_addr, idc.CIC_ITEM, 0x00ff00)
        # set break point
        # idc.add_bpt(call_addr)
        # idc.del_bpt(call_addr)
        Mnemonics = idc.print_insn_mnem(call_addr)
        if Mnemonics[0:1] == "j" or Mnemonics[0:1] == "b":
            if func_name in format_function_offset_dict:
                info = auditFormat(call_addr, func_name, arg_num)
            else:
                info = auditAddr(call_addr, func_name, arg_num)
            table.add_row(info)
    print(table)
    # data_addr = DfirstB(func_addr)
    # while data_addr != BADADDR:
    #     Mnemonics = GetMnem(data_addr)
    #     if DEBUG:
    #         print "Data Mnemonics : %s" % GetMnem(data_addr)
    #         print "Data addr : 0x %s" % data_addr
    #     data_addr = DnextB(func_addr, data_addr)

def auditAddr(call_addr, func_name, arg_num):
    addr = "0x%x" % call_addr
    ret_list = [func_name, addr]
    # local buf size
    local_buf_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)
    if local_buf_size == BADADDR:
        local_buf_size = "get fail"
    else:
        local_buf_size = "0x%x" % local_buf_size
    # get arg
    for num in range(0, arg_num):
        ret_list.append(getArg(call_addr, num))
    ret_list.append(local_buf_size)
    return ret_list

def auditFormat(call_addr, func_name, arg_num):
    addr = "0x%x" % call_addr
    ret_list = [func_name, addr]
    # local buf size
    local_buf_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)
    if local_buf_size == BADADDR:
        local_buf_size = "get fail"
    else:
        local_buf_size = "0x%x" % local_buf_size
    # get arg
    for num in range(0, arg_num):
        ret_list.append(getArg(call_addr, num))
    arg_addr = getArgAddr(call_addr, format_function_offset_dict[func_name])
    string_and_addr = getFormatString(arg_addr)
    format_and_value = []
    if string_and_addr == "get fail":
        ret_list.append("get fail")
    else:
        string_addr = "0x%x" % string_and_addr[0]
        format_and_value.append(string_addr)
        string = string_and_addr[1]
        fmt_num = string.count("%")
        format_and_value.append(fmt_num)
        # mips arg reg is from a0 to a3
        if fmt_num > 3:
            fmt_num = fmt_num - format_function_offset_dict[func_name] - 1
        for num in range(0, fmt_num):
            if arg_num + num > 3:
                break
            format_and_value.append(getArg(call_addr, arg_num + num))
        ret_list.append(format_and_value)
    ret_list.append(local_buf_size)
    return ret_list

# ============================================================
# Enhanced Detection Functions
# ============================================================

def getCallingFunction(addr):
    """Get the function containing the given address"""
    func = idaapi.get_func(addr)
    if func:
        return func.start_ea
    return BADADDR

def traceArgSource(start_addr, regNum, depth=10):
    """
    Trace the source of an argument to detect external input
    Returns: (source_type, source_info)
    source_type: 'external_input', 'static_string', 'stack_var', 'unknown'
    """
    if depth <= 0:
        return ('unknown', None)

    mipsmov = ["move", "lw", "li", "lb", "lui", "lhu", "lbu", "la"]
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr == BADADDR:
        return ('unknown', None)

    Mnemonics = idc.print_insn_mnem(arg_addr)
    operand1 = idc.print_operand(arg_addr, 1)

    # Check if loading from a static string
    if Mnemonics in ["la", "lui", "li"]:
        str_addr = idc.get_name_ea_simple(operand1.split("+")[0].split("-")[0].replace("(", ""))
        if str_addr != BADADDR:
            str_content = idc.get_strlit_contents(str_addr)
            if str_content:
                if isinstance(str_content, bytes):
                    str_content = str_content.decode('utf-8', errors='replace')
                return ('static_string', {'addr': str_addr, 'value': str_content, 'len': len(str_content)})

    # Check if it's a stack variable
    if "sp" in operand1 or "fp" in operand1:
        return ('stack_var', {'operand': operand1})

    # Check if moved from another register - trace further
    if Mnemonics == "move":
        src_reg = operand1
        if src_reg.startswith("$v"):
            # Return value from function call
            # Look for preceding function call
            scan_addr = idc.prev_head(arg_addr)
            scan_count = 0
            while scan_addr != BADADDR and scan_count < 20:
                mnem = idc.print_insn_mnem(scan_addr)
                if mnem in ["jal", "jalr"]:
                    # Get the called function name
                    call_target = idc.print_operand(scan_addr, 0)
                    if call_target in external_input_functions:
                        return ('external_input', {'function': call_target, 'addr': scan_addr})
                    break
                scan_count += 1
                scan_addr = idc.prev_head(scan_addr)

    return ('unknown', None)

def checkCommandInjection(call_addr, func_name):
    """
    Check if command execution function has external input as argument
    Returns risk assessment
    """
    results = []

    # For system(), popen() - check first argument (command string)
    if func_name in ["system"]:
        source_type, source_info = traceArgSource(call_addr, 0, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': 'Command injection - arg from %s()' % source_info['function'],
                'detail': source_info
            })
        elif source_type == 'stack_var':
            results.append({
                'risk': 'MEDIUM',
                'issue': 'Command from stack variable (needs manual review)',
                'detail': source_info
            })
        elif source_type == 'static_string':
            results.append({
                'risk': 'LOW',
                'issue': 'Static command string',
                'detail': source_info
            })
    elif func_name in ["popen"]:
        source_type, source_info = traceArgSource(call_addr, 0, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': 'Popen injection - arg from %s()' % source_info['function'],
                'detail': source_info
            })
    elif func_name in ["execve", "execl", "execle", "execlp", "execv", "execvp"]:
        source_type, source_info = traceArgSource(call_addr, 0, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': 'Exec injection - arg from %s()' % source_info['function'],
                'detail': source_info
            })

    return results

def checkStackOverflow(call_addr, func_name):
    """
    Check for potential stack buffer overflow
    Compares destination buffer size with source length
    """
    results = []
    local_buf_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)

    if func_name in ["strcpy", "strcat"]:
        # Check source (arg1) for static string length
        source_type, source_info = traceArgSource(call_addr, 1, depth=15)
        if source_type == 'static_string' and source_info:
            src_len = source_info.get('len', 0)
            if local_buf_size != BADADDR and src_len > 0:
                if src_len > local_buf_size:
                    results.append({
                        'risk': 'HIGH',
                        'issue':
                            'Buffer overflow: src_len(%d) > frame_size(0x%x)' % (src_len, local_buf_size),
                        'detail': source_info
                    })
                elif src_len > local_buf_size // 2:
                    results.append({
                        'risk': 'MEDIUM',
                        'issue': 'Potential overflow: src_len(%d) close to frame_size(0x%x)' % (src_len, local_buf_size),
                        'detail': source_info
                    })
        elif source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': '%s with external input from %s()' % (func_name, source_info['function']),
                'detail': source_info
            })
    elif func_name in ["sprintf"]:
        # Check format string for potential overflow
        arg_addr = getArgAddr(call_addr, 1)  # format string is arg1
        fmt_info = getFormatString(arg_addr)
        if fmt_info != "get fail":
            fmt_str = fmt_info[1]
            # Check for %s without width limit
            if '%s' in fmt_str:
                results.append({
                    'risk': 'MEDIUM',
                    # Plain literal (no % formatting applied), so a single % is correct
                    'issue': 'sprintf with %s (unbounded string copy)',
                    'detail': {'format': fmt_str}
                })
    elif func_name in ["memcpy", "strncpy", "strncat"]:
        # Check size parameter (arg2)
        source_type, source_info = traceArgSource(call_addr, 2, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': '%s size from external input %s()' % (func_name, source_info['function']),
                'detail': source_info
            })

    return results

def checkIntegerOverflow(call_addr, func_name):
    """
    Check for potential integer overflow in memory allocation
    """
    results = []

    if func_name in ["malloc", "calloc", "realloc", "memalign"]:
        # Check size argument
        size_arg_idx = 0 if func_name == "malloc" else 1
        if func_name == "calloc":
            # calloc(count, size) - check both arguments
            for idx in [0, 1]:
                source_type, source_info = traceArgSource(call_addr, idx, depth=15)
                if source_type == 'external_input':
                    results.append({
                        'risk': 'HIGH',
                        'issue': 'calloc arg%d from external input %s()' % (idx, source_info['function']),
                        'detail': source_info
                    })
        else:
            source_type, source_info = traceArgSource(call_addr, size_arg_idx, depth=15)
            if source_type == 'external_input':
                results.append({
                    'risk': 'HIGH',
                    'issue': '%s size from external 
input %s()' % (func_name, source_info['function']),
                    'detail': source_info
                })

        # Check for arithmetic operations on size (potential integer overflow)
        arg_addr = getArgAddr(call_addr, size_arg_idx)
        if arg_addr != BADADDR:
            # Scan backward for arithmetic operations
            scan_addr = arg_addr
            scan_count = 0
            while scan_addr != BADADDR and scan_count < 10:
                mnem = idc.print_insn_mnem(scan_addr)
                if mnem in ["mul", "mult", "sll", "add", "addu", "addi", "addiu"]:
                    results.append({
                        'risk': 'MEDIUM',
                        'issue': 'Arithmetic (%s) before %s - check for integer overflow' % (mnem, func_name),
                        'detail': {'addr': "0x%x" % scan_addr, 'instruction': idc.GetDisasm(scan_addr)}
                    })
                    break
                scan_count += 1
                scan_addr = idc.prev_head(scan_addr)

    return results

def auditFreeCall(func_name):
    """
    Audit free() calls for potential double-free vulnerabilities
    """
    func_addr = idc.get_name_ea_simple(func_name)
    if func_addr == BADADDR:
        return

    print(printFunc(func_name + " (Double-Free Analysis)"))

    # Group by containing function
    calls_by_func = defaultdict(list)
    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        Mnemonics = idc.print_insn_mnem(call_addr)
        if Mnemonics[0:1] == "j" or Mnemonics[0:1] == "b":
            containing_func = getCallingFunction(call_addr)
            arg_info = getArg(call_addr, 0)  # First argument to free
            calls_by_func[containing_func].append({
                'addr': call_addr,
                'arg': arg_info
            })

    # Analyze each function for potential double-free
    table = PrettyTable(["function", "free_count", "addresses", "args", "risk"])
    for func_start, calls in calls_by_func.items():
        func_name_str = idc.get_func_name(func_start) or ("0x%x" % func_start)
        addrs = [("0x%x" % c['addr']) for c in calls]
        args = [c['arg'] for c in calls]

        # Risk assessment
        risk = "LOW"
        if len(calls) > 1:
            # Check if same argument pattern appears multiple times
            arg_counts = defaultdict(int)
            for arg in args:
                arg_counts[arg] += 1
            for arg, count in arg_counts.items():
                if count > 1 and arg != "get fail":
                    risk = "HIGH"
                    idc.set_color(calls[0]['addr'], idc.CIC_ITEM, RISK_HIGH)
                    break
            else:
                if len(calls) > 2:
                    risk = "MEDIUM"

        table.add_row([func_name_str, len(calls), ", ".join(addrs), ", ".join(args), risk])

    print(table)

def auditEnhanced(func_name):
    """
    Enhanced audit with additional vulnerability checks
    """
    global audit_results

    func_addr = idc.get_name_ea_simple(func_name)
    if func_addr == BADADDR:
        return False

    print(printFunc(func_name + " (Enhanced Analysis)"))

    # Determine check type based on function
    check_cmd_injection = func_name in command_execution_function
    check_overflow = func_name in dangerous_functions + attention_function
    check_int_overflow = func_name in memory_alloc_functions
    check_format_string = func_name in format_string_functions

    table = PrettyTable(["addr", "risk", "issue", "detail"])

    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        Mnemonics = idc.print_insn_mnem(call_addr)
        if Mnemonics[0:1] == "j" or Mnemonics[0:1] == "b":
            findings = []
            if check_cmd_injection:
                findings.extend(checkCommandInjection(call_addr, func_name))
            if check_overflow:
                findings.extend(checkStackOverflow(call_addr, func_name))
            if check_int_overflow:
                findings.extend(checkIntegerOverflow(call_addr, func_name))
            if check_format_string:
                findings.extend(checkFormatStringVuln(call_addr, func_name))

            # Add findings to table and global results
            for finding in findings:
                risk = finding['risk']
                # Set color based on risk
                if risk == 'HIGH':
                    idc.set_color(call_addr, idc.CIC_ITEM, RISK_HIGH)
                elif risk == 'MEDIUM':
                    idc.set_color(call_addr, idc.CIC_ITEM, RISK_MEDIUM)
                else:
                    idc.set_color(call_addr, idc.CIC_ITEM, RISK_LOW)

                detail_str = str(finding.get('detail', ''))[:50]
                table.add_row(["0x%x" % call_addr, risk, finding['issue'], detail_str])

                # Add to global results for export
                finding['address'] = "0x%x" % call_addr
                finding['function'] = func_name
                audit_results.append(finding)

    if table.rowcount > 0:
        print(table)
    else:
        print(" No enhanced findings for %s" % func_name)

def collectTaintSources():
    """
    Collect all external input sources in the binary for taint analysis
    """
    global taint_sources
    taint_sources = {}

    print("\n[*] Collecting external input sources...")
    for func_name in external_input_functions:
        func_addr = idc.get_name_ea_simple(func_name)
        if func_addr == BADADDR:
            continue
        for call_addr in idautils.CodeRefsTo(func_addr, 0):
            Mnemonics = idc.print_insn_mnem(call_addr)
            if Mnemonics[0:1] == "j" or Mnemonics[0:1] == "b":
                taint_sources[call_addr] = func_name

    print(" Found %d external input call sites" % len(taint_sources))
    return taint_sources

# ============================================================
# Basic Block Analysis (Fix cross-block issues)
# ============================================================

def getBasicBlockBounds(addr):
    """Get the basic block boundaries containing the address"""
    func = idaapi.get_func(addr)
    if not func:
        return (BADADDR, BADADDR)
    try:
        flowchart = idaapi.FlowChart(func)
        for block in flowchart:
            if block.start_ea <= addr < block.end_ea:
                return (block.start_ea, block.end_ea)
    except:
        pass
    return (BADADDR, BADADDR)

def getArgAddrInBlock(start_addr, regNum):
    """
    Improved argument address detection within basic block boundaries
    Avoids crossing basic block boundaries which could lead to false positives
    """
    mipscondition = ["bn", "be", "bg", "bl"]
    scan_deep = 50
    count = 0
    reg = "$a" + str(regNum)

    # Get basic block bounds
    block_start, block_end = getBasicBlockBounds(start_addr)

    # try to get in the next (code references from this address)
    for next_addr in idautils.CodeRefsFrom(start_addr, 0):
        if next_addr != BADADDR and reg == idc.print_operand(next_addr, 0):
            return next_addr

    # try to get before (within same basic block)
    before_addr = idc.prev_head(start_addr)
    while before_addr != BADADDR:
        # Stop if we cross basic block boundary
        if block_start != BADADDR and before_addr < block_start:
            break
        if reg == idc.print_operand(before_addr, 0):
            Mnemonics = idc.print_insn_mnem(before_addr)
            if Mnemonics[0:2] in mipscondition:
                pass
            elif Mnemonics[0:1] == "j":
                pass
            else:
                return before_addr
        count = count + 1
        if 
count > scan_deep:
            break
        before_addr = idc.prev_head(before_addr)
    return BADADDR

# ============================================================
# Format String Vulnerability Detection
# ============================================================

def checkFormatStringVuln(call_addr, func_name):
    """
    Enhanced format string vulnerability detection
    - Detects %n format specifier (write primitive)
    - Detects user-controlled format string
    """
    results = []

    if func_name not in format_function_offset_dict:
        return results

    fmt_arg_idx = format_function_offset_dict[func_name]
    arg_addr = getArgAddrInBlock(call_addr, fmt_arg_idx)
    if arg_addr == BADADDR:
        return results

    # Try to get the format string
    fmt_info = getFormatString(arg_addr)
    if fmt_info != "get fail":
        fmt_str = fmt_info[1]

        # Check for %n (write primitive - HIGH risk)
        if '%n' in fmt_str or '%hn' in fmt_str or '%hhn' in fmt_str or '%ln' in fmt_str:
            results.append({
                'risk': 'HIGH',
                # Plain literal (no % formatting applied), so a single % is correct
                'issue': 'Format string with %n write primitive',
                'detail': {'format': fmt_str, 'addr': "0x%x" % fmt_info[0]},
                'type': 'format_string'
            })

        # Check for multiple format specifiers (potential overflow)
        fmt_count = len(re.findall(r'%[^%]', fmt_str))
        if fmt_count > 10:
            results.append({
                'risk': 'MEDIUM',
                'issue': 'Format string with many specifiers (%d)' % fmt_count,
                'detail': {'format': fmt_str[:50] + '...', 'count': fmt_count},
                'type': 'format_string'
            })
    else:
        # Format string is not static - check if user controlled
        source_type, source_info = traceArgSource(call_addr, fmt_arg_idx, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': 'User-controlled format string from %s()' % source_info['function'],
                'detail': source_info,
                'type': 'format_string'
            })
        elif source_type == 'stack_var':
            results.append({
                'risk': 'MEDIUM',
                'issue': 'Format string from stack variable (needs review)',
                'detail': source_info,
                'type': 'format_string'
            })

    return results

# ============================================================
# Data Flow
# Analysis - Forward Tracking
# ============================================================

def traceDataFlowForward(start_addr, src_func, max_depth=5):
    """
    Trace where data from external input flows to
    Tracks return values from read/recv etc. to dangerous sinks
    """
    results = []
    if max_depth <= 0:
        return results

    # Get the function containing this call
    func = idaapi.get_func(start_addr)
    if not func:
        return results

    # For read/recv/recvfrom, the buffer is arg1 ($a1);
    # for fread/fgets/gets, the buffer is arg0 ($a0)
    # Track where this buffer is used
    buffer_arg_idx = 1 if src_func in ["read", "recv", "recvfrom"] else 0
    buffer_operand = getArg(start_addr, buffer_arg_idx)
    # An empty operand would match every argument below, so treat it as a failure
    if not buffer_operand or buffer_operand == "get fail":
        return results

    # Scan forward in the function to find uses of this buffer
    current_addr = idc.next_head(start_addr)
    func_end = func.end_ea
    scan_count = 0
    max_scan = 100

    while current_addr < func_end and current_addr != BADADDR and scan_count < max_scan:
        mnem = idc.print_insn_mnem(current_addr)

        # Check if this is a function call
        if mnem in ["jal", "jalr"]:
            call_target = idc.print_operand(current_addr, 0)

            # Check if any dangerous function is called with our tainted buffer
            all_dangerous = dangerous_functions + command_execution_function + format_string_functions
            if call_target in all_dangerous:
                # Check if our buffer is passed as an argument
                for arg_idx in range(4):  # MIPS uses $a0-$a3
                    arg_operand = getArg(current_addr, arg_idx)
                    # Skip failed or empty operands: "" is a substring of every string
                    if not arg_operand or arg_operand == "get fail":
                        continue
                    if buffer_operand in arg_operand or arg_operand in buffer_operand:
                        risk = 'HIGH' if call_target in command_execution_function else 'MEDIUM'
                        results.append({
                            'risk': risk,
                            'issue': 'Tainted data from %s() flows to %s()' % (src_func, call_target),
                            'detail': {
                                'source': "0x%x" % start_addr,
                                'sink': "0x%x" % current_addr,
                                'sink_func': call_target,
                                'buffer': buffer_operand
                            },
                            'type': 'data_flow'
                        })
                        break

        current_addr = idc.next_head(current_addr)
        scan_count += 1

    return results

def analyzeDataFlow():
    """
    Perform data flow analysis from external inputs to dangerous sinks
    """
    global data_flow_graph
    results = 
[] print("\n[*] Analyzing data flow from external inputs...") for call_addr, func_name in taint_sources.items(): if func_name in ["read", "recv", "recvfrom", "fread", "fgets", "gets"]: flow_results = traceDataFlowForward(call_addr, func_name) results.extend(flow_results) # Store in graph for visualization for r in flow_results: if 'detail' in r and 'sink' in r['detail']: data_flow_graph[call_addr].append((r['detail']['sink'], r['detail']['sink_func'])) print(" Found %d data flow issues" % len(results)) return results # ============================================================ # Wrapper Function Detection # ============================================================ def detectWrapperFunctions(): """ Detect wrapper functions that call dangerous functions e.g., my_strcpy that internally calls strcpy """ global wrapper_functions wrapper_functions = {} print("\n[*] Detecting wrapper functions...") all_dangerous = dangerous_functions + command_execution_function for dangerous_func in all_dangerous: func_addr = idc.get_name_ea_simple(dangerous_func) if func_addr == BADADDR: continue # Find all callers of this dangerous function for call_addr in idautils.CodeRefsTo(func_addr, 0): mnem = idc.print_insn_mnem(call_addr) if mnem[0:1] != "j" and mnem[0:1] != "b": continue # Get the function containing this call caller_func = idaapi.get_func(call_addr) if not caller_func: continue caller_name = idc.get_func_name(caller_func.start_ea) if not caller_name: continue # Heuristics for wrapper detection: # 1. Function name contains common wrapper patterns # 2. Function is small (likely just a wrapper) # 3. 
Function forwards arguments directly is_wrapper = False wrapper_type = None # Check name patterns wrapper_patterns = ['my_', 'safe_', 'wrap_', '_wrapper', '_safe', 'do_', 'internal_'] for pattern in wrapper_patterns: if pattern in caller_name.lower(): is_wrapper = True wrapper_type = 'name_match' break # Check function size (small functions are likely wrappers) func_size = caller_func.end_ea - caller_func.start_ea if func_size < 100 and not is_wrapper: # Less than ~25 instructions # Count how many other calls this function makes call_count = 0 for addr in idautils.FuncItems(caller_func.start_ea): m = idc.print_insn_mnem(addr) if m in ["jal", "jalr"]: call_count += 1 if call_count <= 2: # Mostly just calls the dangerous function is_wrapper = True wrapper_type = 'small_func' if is_wrapper: if caller_func.start_ea not in wrapper_functions: wrapper_functions[caller_func.start_ea] = { 'name': caller_name, 'wraps': [], 'type': wrapper_type } wrapper_functions[caller_func.start_ea]['wraps'].append(dangerous_func) print(" Found %d potential wrapper functions" % len(wrapper_functions)) return wrapper_functions def auditWrapperFunctions(): """ Audit detected wrapper functions """ if not wrapper_functions: detectWrapperFunctions() if not wrapper_functions: print(" No wrapper functions detected") return print(printFunc("Wrapper Functions")) table = PrettyTable(["wrapper_name", "address", "wraps", "detection_type"]) for func_addr, info in wrapper_functions.items(): idc.set_color(func_addr, idc.CIC_FUNC, RISK_MEDIUM) table.add_row([ info['name'], "0x%x" % func_addr, ", ".join(info['wraps']), info['type'] ]) print(table) # Also audit calls to wrapper functions print("\n Calls to wrapper functions:") wrapper_table = PrettyTable(["wrapper", "call_addr", "caller_func"]) for func_addr, info in wrapper_functions.items(): for call_addr in idautils.CodeRefsTo(func_addr, 0): mnem = idc.print_insn_mnem(call_addr) if mnem[0:1] == "j" or mnem[0:1] == "b": caller_func = 
idc.get_func_name(call_addr) idc.set_color(call_addr, idc.CIC_ITEM, RISK_MEDIUM) wrapper_table.add_row([info['name'], "0x%x" % call_addr, caller_func or "unknown"]) if wrapper_table.rowcount > 0: print(wrapper_table) # ============================================================ # Result Export Functions # ============================================================ def addFinding(finding): """Add a finding to the global results list""" global audit_results audit_results.append(finding) def exportResultsJSON(filename): """Export audit results to JSON file""" try: output_path = os.path.join(get_output_dir(), filename) with open(output_path, 'w', encoding='utf-8') as f: json.dump(audit_results, f, indent=2, ensure_ascii=False) print("[*] Results exported to: %s" % output_path) return True except Exception as e: print("[!] Error exporting JSON: %s" % str(e)) return False def exportResultsCSV(filename): """Export audit results to CSV file""" try: output_path = os.path.join(get_output_dir(), filename) with open(output_path, 'w', newline='', encoding='utf-8') as f: if audit_results: writer = csv.DictWriter(f, fieldnames=['risk', 'issue', 'type', 'address', 'function', 'detail']) writer.writeheader() for result in audit_results: row = { 'risk': result.get('risk', ''), 'issue': result.get('issue', ''), 'type': result.get('type', ''), 'address': result.get('address', ''), 'function': result.get('function', ''), 'detail': str(result.get('detail', ''))[:100] } writer.writerow(row) print("[*] Results exported to: %s" % output_path) return True except Exception as e: print("[!] Error exporting CSV: %s" % str(e)) return False def exportResultsHTML(filename): """Export audit results to HTML file""" try: output_path = os.path.join(get_output_dir(), filename) html_content = """ MIPS Audit Report

<h1>MIPS Security Audit Report</h1>
<p>Total Findings: %d | HIGH: %d | MEDIUM: %d | LOW: %d</p>
<table>
<tr><th>Risk</th><th>Issue</th><th>Type</th><th>Address</th><th>Function</th><th>Detail</th></tr>
""" % (
            len(audit_results),
            len([r for r in audit_results if r.get('risk') == 'HIGH']),
            len([r for r in audit_results if r.get('risk') == 'MEDIUM']),
            len([r for r in audit_results if r.get('risk') == 'LOW'])
        )

        for result in audit_results:
            risk = result.get('risk', '')
            html_content += """<tr class="%s"><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td></tr>
""" % (
                risk,
                risk,
                result.get('issue', ''),
                result.get('type', ''),
                result.get('address', ''),
                result.get('function', ''),
                str(result.get('detail', ''))[:100]
            )

        html_content += """</table>
""" with open(output_path, 'w', encoding='utf-8') as f: f.write(html_content) print("[*] Results exported to: %s" % output_path) return True except Exception as e: print("[!] Error exporting HTML: %s" % str(e)) return False def exportResults(base_filename="mipsAudit_results"): """Export results as HTML with timestamp""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = "%s_%s.html" % (base_filename, timestamp) exportResultsHTML(filename) return filename def mipsAudit(): """Main audit function with all enhanced features""" global audit_results, total_functions, processed_functions # Reset global state audit_results = [] # the word create with figlet start = ''' _ _ _ _ _ _ __ ___ (_)_ __ ___ / \ _ _ __| (_) |_ | '_ ` _ \| | '_ \/ __| / _ \| | | |/ _` | | __| | | | | | | | |_) \__ \/ ___ \ |_| | (_| | | |_ |_| |_| |_|_| .__/|___/_/ \_\__,_|\__,_|_|\__| |_| code by giantbranch 2018.05 updated for IDA 7.x+ & Python 3 enhanced detection v3.0 20260129 ''' print(start) # Load external configuration if available load_config() # Calculate total functions for progress display all_functions = (dangerous_functions + attention_function + command_execution_function + memory_alloc_functions + memory_free_functions + format_string_functions) total_functions = len(set(all_functions)) processed_functions = 0 # Collect taint sources first for enhanced analysis collectTaintSources() print("\n" + "=" * 60) print(" PHASE 1: Basic Function Audit") print("=" * 60) print("\nAuditing dangerous functions ......") for i, func_name in enumerate(dangerous_functions): audit(func_name) show_progress(i + 1, len(dangerous_functions), "Dangerous funcs") print("\nAuditing attention function ......") for i, func_name in enumerate(attention_function): audit(func_name) show_progress(i + 1, len(attention_function), "Attention funcs") print("\nAuditing command execution function ......") for i, func_name in enumerate(command_execution_function): audit(func_name) show_progress(i + 1, 
len(command_execution_function), "Cmd exec funcs") print("\n" + "=" * 60) print(" PHASE 2: Enhanced Vulnerability Detection") print("=" * 60) print("\n[*] Enhanced analysis: Command Injection Detection") for func_name in command_execution_function: auditEnhanced(func_name) print("\n[*] Enhanced analysis: Stack Overflow Detection") for func_name in dangerous_functions: auditEnhanced(func_name) print("\n[*] Enhanced analysis: Integer Overflow Detection") for func_name in memory_alloc_functions: auditEnhanced(func_name) print("\n[*] Enhanced analysis: Format String Vulnerability Detection") for func_name in format_string_functions: auditEnhanced(func_name) print("\n[*] Enhanced analysis: Double-Free Detection") for func_name in memory_free_functions: auditFreeCall(func_name) print("\n" + "=" * 60) print(" PHASE 3: Advanced Analysis") print("=" * 60) # Data flow analysis flow_results = analyzeDataFlow() if flow_results: print("\nData Flow Analysis Results:") table = PrettyTable(["risk", "issue", "source", "sink"]) for r in flow_results: table.add_row([ r['risk'], r['issue'], r['detail'].get('source', ''), r['detail'].get('sink', '') ]) audit_results.append(r) print(table) # Wrapper function detection print("\n[*] Detecting and auditing wrapper functions...") auditWrapperFunctions() print("\n" + "=" * 60) print(" PHASE 4: Results Summary & Export") print("=" * 60) # Summary high_count = len([r for r in audit_results if r.get('risk') == 'HIGH']) medium_count = len([r for r in audit_results if r.get('risk') == 'MEDIUM']) low_count = len([r for r in audit_results if r.get('risk') == 'LOW']) print("\n[*] Audit Summary:") print(" Total findings: %d" % len(audit_results)) print(" HIGH risk: %d" % high_count) print(" MEDIUM risk: %d" % medium_count) print(" LOW risk: %d" % low_count) # Export results export_file = None if audit_results: print("\n[*] Exporting results...") export_file = exportResults() print("\nExported: %s" % export_file) print("\n" + "=" * 60) print(" 
Finished! Enjoy the result ~")
    print("=" * 60)


# Check processor architecture (may be useful in future)
# info = idaapi.get_inf_structure()
#
# if info.is_64bit():
#     bits = 64
# elif info.is_32bit():
#     bits = 32
# else:
#     bits = 16
#
# try:
#     is_be = info.is_be()
# except:
#     is_be = info.mf
# endian = "big" if is_be else "little"
#
# print('Processor: {}, {}bit, {} endian'.format(info.procName, bits, endian))
# # Result: Processor: mipsr, 32bit, big endian

mipsAudit()

================================================
FILE: oldversion(py2)/mipsAudit.py
================================================
# -*- coding: utf-8 -*-

# reference
# 《ida pro 权威指南》
# 《python 灰帽子》
# 《家用路由器0day漏洞挖掘》
# https://github.com/wangzery/SearchOverflow/blob/master/SearchOverflow.py

from idaapi import *
from prettytable import PrettyTable

DEBUG = True

# fgetc,fgets,fread,fprintf,
# vsprintf

# function names to audit
dangerous_functions = [
    "strcpy",
    "strcat",
    "sprintf",
    "read",
    "getenv"
]

attention_function = [
    "memcpy",
    "strncpy",
    "sscanf",
    "strncat",
    "snprintf",
    "vprintf",
    "printf"
]

command_execution_function = [
    "system",
    "execve",
    "popen",
    "unlink"
]

# number of arguments each function takes
one_arg_function = [
    "getenv",
    "system",
    "unlink"
]

two_arg_function = [
    "strcpy",
    "strcat",
    "popen"
]

three_arg_function = [
    "strncpy",
    "strncat",
    "memcpy",
    "execve",
    "read"
]

# index of the format-string argument for each format function
format_function_offset_dict = {
    "sprintf":1,
    "sscanf":1,
    "snprintf":2,
    "vprintf":0,
    "printf":0
}

def printFunc(func_name):
    string1 = "========================================"
    string2 = "========== Auditing " + func_name + " "
    strlen = len(string1) - len(string2)
    return string1 + "\n" + string2 + '=' * strlen + "\n" + string1

def getFuncAddr(func_name):
    func_addr = LocByName(func_name)
    if func_addr != BADADDR:
        print printFunc(func_name)
        # print func_name + " Addr : 0x %x" % func_addr
        return func_addr
    return False

def getFormatString(addr):
    op_num = 1
    # GetOpType Return value
    #define o_void      0  // No Operand                           ----------
    #define o_reg       1  // General Register (al, ax, es, ds...) reg
    #define o_mem       2  // Direct Memory Reference (DATA)       addr
    #define o_phrase    3  // Memory Ref [Base Reg + Index Reg]    phrase
    #define o_displ     4  // Memory Reg [Base Reg + Index Reg + Displacement] phrase+addr
    #define o_imm       5  // Immediate Value                      value
    #define o_far       6  // Immediate Far Address (CODE)         addr
    #define o_near      7  // Immediate Near Address (CODE)        addr
    #define o_idpspec0  8  // IDP specific type
    #define o_idpspec1  9  // IDP specific type
    #define o_idpspec2  10 // IDP specific type
    #define o_idpspec3  11 // IDP specific type
    #define o_idpspec4  12 // IDP specific type
    #define o_idpspec5  13 // IDP specific type
    # If the second operand is not an immediate, try the next one
    if(GetOpType(addr ,op_num) != 5):
        op_num = op_num + 1
    if GetOpType(addr ,op_num) != 5:
        return "get fail"
    op_string = GetOpnd(addr, op_num).split(" ")[0].split("+")[0].split("-")[0].replace("(", "")
    string_addr = LocByName(op_string)
    if string_addr == BADADDR:
        return "get fail"
    string = str(GetString(string_addr))
    return [string_addr, string]

def getArgAddr(start_addr, regNum):
    mipscondition = ["bn", "be" , "bg", "bl"]
    scan_deep = 50
    count = 0
    reg = "$a" + str(regNum)
    # first check the instruction following the call
    next_addr = Rfirst(start_addr)
    if next_addr != BADADDR and reg == GetOpnd(next_addr, 0):
        return next_addr
    # otherwise scan backwards from the call
    before_addr = RfirstB(start_addr)
    while before_addr != BADADDR:
        if reg == GetOpnd(before_addr, 0):
            Mnemonics = GetMnem(before_addr)
            if Mnemonics[0:2] in mipscondition:
                pass
            elif Mnemonics[0:1] == "j":
                pass
            else:
                return before_addr
        count = count + 1
        if count > scan_deep:
            break
        before_addr = RfirstB(before_addr)
    return BADADDR

def getArg(start_addr, regNum):
    mipsmov = ["move", "lw", "li", "lb", "lui", "lhu", "lbu", "la"]
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr != BADADDR:
        Mnemonics = GetMnem(arg_addr)
        if Mnemonics[0:3] == "add":
            if GetOpnd(arg_addr, 2) == "":
                arg = GetOpnd(arg_addr, 0) + "+" + GetOpnd(arg_addr, 1)
            else:
                arg = GetOpnd(arg_addr, 1) + "+" + GetOpnd(arg_addr, 2)
        elif Mnemonics[0:3] == "sub":
            if GetOpnd(arg_addr, 2) == "":
                arg = GetOpnd(arg_addr, 0) + "-" + GetOpnd(arg_addr, 1)
            else:
                arg = GetOpnd(arg_addr, 1) + "-" + GetOpnd(arg_addr, 2)
        elif Mnemonics in mipsmov:
            arg = GetOpnd(arg_addr, 1)
        else:
            arg = GetDisasm(arg_addr).split("#")[0]
        MakeComm(arg_addr, "addr: 0x%x " % start_addr + "-------> arg" + str((int(regNum)+1)) + " : " + arg)
        return arg
    else:
        return "get fail"

def audit(func_name):
    func_addr = getFuncAddr(func_name)
    if func_addr == False:
        return False

    # get arg num and set table
    if func_name in one_arg_function:
        arg_num = 1
    elif func_name in two_arg_function:
        arg_num = 2
    elif func_name in three_arg_function:
        arg_num = 3
    elif func_name in format_function_offset_dict:
        arg_num = format_function_offset_dict[func_name] + 1
    else:
        print "The %s function is not listed in any of the argument-count arrays; please add it to one, e.g. `two_arg_function`" % func_name
        return
    # mipscall = ["jal", "jalr", "bal", "jr"]
    table_head = ["func_name", "addr"]
    for num in xrange(0,arg_num):
        table_head.append("arg"+str(num+1))
    if func_name in format_function_offset_dict:
        table_head.append("format&value[string_addr, num of '%', fmt_arg...]")
    table_head.append("local_buf_size")
    table = PrettyTable(table_head)

    # get first call
    call_addr = RfirstB(func_addr)
    while call_addr != BADADDR:
        # set color: green (red=0x0000ff, blue=0xff0000)
        SetColor(call_addr, CIC_ITEM, 0x00ff00)
        # set break point
        # AddBpt(call_addr)
        # DelBpt(call_addr)
        # conditional breakpoint, if needed
        # SetBptCnd(ea, 'strstr(GetString(Dword(esp+4),-1, 0), "SAEXT.DLL") != -1')
        Mnemonics = GetMnem(call_addr)
        # print "Mnemonics : %s" % Mnemonics
        # if Mnemonics in mipscall:
        if Mnemonics[0:1] == "j" or Mnemonics[0:1] == "b":
            # print func + " addr : 0x%x" % call_addr
            if func_name in format_function_offset_dict:
                info = auditFormat(call_addr, func_name, arg_num)
            else:
                info = auditAddr(call_addr, func_name, arg_num)
            table.add_row(info)
        call_addr = RnextB(func_addr, call_addr)
    print table
    # data_addr = DfirstB(func_addr)
    # while data_addr != BADADDR:
    #     Mnemonics = GetMnem(data_addr)
    #     if DEBUG:
    #         print "Data Mnemonics : %s" % GetMnem(data_addr)
    #         print "Data addr : 0x %s" % data_addr
    #     data_addr = DnextB(func_addr, data_addr)

def auditAddr(call_addr, func_name, arg_num):
    addr = "0x%x" % call_addr
    ret_list = [func_name, addr]
    # local buf size
    local_buf_size = GetFunctionAttr(call_addr , FUNCATTR_FRSIZE)
    if local_buf_size == BADADDR :
        local_buf_size = "get fail"
    else:
        local_buf_size = "0x%x" % local_buf_size
    # get arg
    for num in xrange(0,arg_num):
        ret_list.append(getArg(call_addr, num))
    ret_list.append(local_buf_size)
    return ret_list

def auditFormat(call_addr, func_name, arg_num):
    addr = "0x%x" % call_addr
    ret_list = [func_name, addr]
    # local buf size
    local_buf_size = GetFunctionAttr(call_addr , FUNCATTR_FRSIZE)
    if local_buf_size == BADADDR :
        local_buf_size = "get fail"
    else:
        local_buf_size = "0x%x" % local_buf_size
    # get arg
    for num in xrange(0,arg_num):
        ret_list.append(getArg(call_addr, num))
    arg_addr = getArgAddr(call_addr, format_function_offset_dict[func_name])
    string_and_addr = getFormatString(arg_addr)
    format_and_value = []
    if string_and_addr == "get fail":
        ret_list.append("get fail")
    else:
        string_addr = "0x%x" % string_and_addr[0]
        format_and_value.append(string_addr)
        string = string_and_addr[1]
        fmt_num = string.count("%")
        format_and_value.append(fmt_num)
        # MIPS passes the first four arguments in registers $a0 to $a3
        if fmt_num > 3:
            fmt_num = fmt_num - format_function_offset_dict[func_name] - 1
        for num in xrange(0,fmt_num):
            if arg_num + num > 3:
                break
            format_and_value.append(getArg(call_addr, arg_num + num))
        ret_list.append(format_and_value)
    # format_string = str(getFormatString(arg_addr)[1])
    # print " format String: " + format_string
    # ret_list.append([string_addr])
    ret_list.append(local_buf_size)
    return ret_list

def mipsAudit():
    # banner generated with figlet
    start = '''
           _              _             _ _ _
 _ __ ___ (_)_ __  ___   / \  _   _  __| (_) |_
| '_ ` _ \| | '_ \/ __| / _ \| | | |/ _` | | __|
| | | | | | | |_) \__ \/ ___ \ |_| | (_| | | |_
|_| |_| |_|_| .__/|___/_/   \_\__,_|\__,_|_|\__|
            |_|
                    code by giantbranch 2018.05
    '''
    print start
    print "Auditing dangerous functions ......"
    for func_name in dangerous_functions:
        audit(func_name)

    print "Auditing attention function ......"
    for func_name in attention_function:
        audit(func_name)

    print "Auditing command execution function ......"
    for func_name in command_execution_function:
        audit(func_name)

    print "Finished! Enjoy the result ~"

# Architecture detection code; may be useful later
# info = idaapi.get_inf_structure()

# if info.is_64bit():
#     bits = 64
# elif info.is_32bit():
#     bits = 32
# else:
#     bits = 16

# try:
#     is_be = info.is_be()
# except:
#     is_be = info.mf
# endian = "big" if is_be else "little"
# print 'Processor: {}, {}bit, {} endian'.format(info.procName, bits, endian)
# # Result: Processor: mipsr, 32bit, big endian

mipsAudit()

================================================
FILE: oldversion(py2)/prettytable.py
================================================
#!/usr/bin/env python
#
# Copyright (c) 2009-2013, Luke Maurits
# All rights reserved.
# With contributions from:
#  * Chris Clark
#  * Klein Stephane
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
#   derived from this software without specific prior written permission.
# # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. __version__ = "0.7.2" import copy import csv import random import re import sys import textwrap import itertools import unicodedata py3k = sys.version_info[0] >= 3 if py3k: unicode = str basestring = str itermap = map iterzip = zip uni_chr = chr from html.parser import HTMLParser else: itermap = itertools.imap iterzip = itertools.izip uni_chr = unichr from HTMLParser import HTMLParser if py3k and sys.version_info[1] >= 2: from html import escape else: from cgi import escape # hrule styles FRAME = 0 ALL = 1 NONE = 2 HEADER = 3 # Table styles DEFAULT = 10 MSWORD_FRIENDLY = 11 PLAIN_COLUMNS = 12 RANDOM = 20 _re = re.compile("\033\[[0-9;]*m") def _get_size(text): lines = text.split("\n") height = len(lines) width = max([_str_block_width(line) for line in lines]) return (width, height) class PrettyTable(object): def __init__(self, field_names=None, **kwargs): """Return a new PrettyTable instance Arguments: encoding - Unicode encoding scheme used to decode any encoded input field_names - list or tuple of field names fields - list or tuple of field names to include in displays start - index of first data row to include in output end - index of last data row to include in output PLUS ONE (list slice style) 
header - print a header showing field names (True or False) header_style - stylisation to apply to field names in header ("cap", "title", "upper", "lower" or None) border - print a border around the table (True or False) hrules - controls printing of horizontal rules after rows. Allowed values: FRAME, HEADER, ALL, NONE vrules - controls printing of vertical rules between columns. Allowed values: FRAME, ALL, NONE int_format - controls formatting of integer data float_format - controls formatting of floating point data padding_width - number of spaces on either side of column data (only used if left and right paddings are None) left_padding_width - number of spaces on left hand side of column data right_padding_width - number of spaces on right hand side of column data vertical_char - single character string used to draw vertical lines horizontal_char - single character string used to draw horizontal lines junction_char - single character string used to draw line junctions sortby - name of field to sort rows by sort_key - sorting key function, applied to data points before sorting valign - default valign for each row (None, "t", "m" or "b") reversesort - True or False to sort in descending or ascending order""" self.encoding = kwargs.get("encoding", "UTF-8") # Data self._field_names = [] self._align = {} self._valign = {} self._max_width = {} self._rows = [] if field_names: self.field_names = field_names else: self._widths = [] # Options self._options = "start end fields header border sortby reversesort sort_key attributes format hrules vrules".split() self._options.extend("int_format float_format padding_width left_padding_width right_padding_width".split()) self._options.extend("vertical_char horizontal_char junction_char header_style valign xhtml print_empty".split()) for option in self._options: if option in kwargs: self._validate_option(option, kwargs[option]) else: kwargs[option] = None self._start = kwargs["start"] or 0 self._end = kwargs["end"] or None 
self._fields = kwargs["fields"] or None if kwargs["header"] in (True, False): self._header = kwargs["header"] else: self._header = True self._header_style = kwargs["header_style"] or None if kwargs["border"] in (True, False): self._border = kwargs["border"] else: self._border = True self._hrules = kwargs["hrules"] or FRAME self._vrules = kwargs["vrules"] or ALL self._sortby = kwargs["sortby"] or None if kwargs["reversesort"] in (True, False): self._reversesort = kwargs["reversesort"] else: self._reversesort = False self._sort_key = kwargs["sort_key"] or (lambda x: x) self._int_format = kwargs["int_format"] or {} self._float_format = kwargs["float_format"] or {} self._padding_width = kwargs["padding_width"] or 1 self._left_padding_width = kwargs["left_padding_width"] or None self._right_padding_width = kwargs["right_padding_width"] or None self._vertical_char = kwargs["vertical_char"] or self._unicode("|") self._horizontal_char = kwargs["horizontal_char"] or self._unicode("-") self._junction_char = kwargs["junction_char"] or self._unicode("+") if kwargs["print_empty"] in (True, False): self._print_empty = kwargs["print_empty"] else: self._print_empty = True self._format = kwargs["format"] or False self._xhtml = kwargs["xhtml"] or False self._attributes = kwargs["attributes"] or {} def _unicode(self, value): if not isinstance(value, basestring): value = str(value) if not isinstance(value, unicode): value = unicode(value, self.encoding, "strict") return value def _justify(self, text, width, align): excess = width - _str_block_width(text) if align == "l": return text + excess * " " elif align == "r": return excess * " " + text else: if excess % 2: # Uneven padding # Put more space on right if text is of odd length... if _str_block_width(text) % 2: return (excess//2)*" " + text + (excess//2 + 1)*" " # and more space on left if text is of even length else: return (excess//2 + 1)*" " + text + (excess//2)*" " # Why distribute extra space this way? 
To match the behaviour of # the inbuilt str.center() method. else: # Equal padding on either side return (excess//2)*" " + text + (excess//2)*" " def __getattr__(self, name): if name == "rowcount": return len(self._rows) elif name == "colcount": if self._field_names: return len(self._field_names) elif self._rows: return len(self._rows[0]) else: return 0 else: raise AttributeError(name) def __getitem__(self, index): new = PrettyTable() new.field_names = self.field_names for attr in self._options: setattr(new, "_"+attr, getattr(self, "_"+attr)) setattr(new, "_align", getattr(self, "_align")) if isinstance(index, slice): for row in self._rows[index]: new.add_row(row) elif isinstance(index, int): new.add_row(self._rows[index]) else: raise Exception("Index %s is invalid, must be an integer or slice" % str(index)) return new if py3k: def __str__(self): return self.__unicode__() else: def __str__(self): return self.__unicode__().encode(self.encoding) def __unicode__(self): return self.get_string() ############################## # ATTRIBUTE VALIDATORS # ############################## # The method _validate_option is all that should be used elsewhere in the code base to validate options. # It will call the appropriate validation method for that option. The individual validation methods should # never need to be called directly (although nothing bad will happen if they *are*). # Validation happens in TWO places. # Firstly, in the property setters defined in the ATTRIBUTE MANAGMENT section. 
# Secondly, in the _get_options method, where keyword arguments are mixed with persistent settings def _validate_option(self, option, val): if option in ("field_names"): self._validate_field_names(val) elif option in ("start", "end", "max_width", "padding_width", "left_padding_width", "right_padding_width", "format"): self._validate_nonnegative_int(option, val) elif option in ("sortby"): self._validate_field_name(option, val) elif option in ("sort_key"): self._validate_function(option, val) elif option in ("hrules"): self._validate_hrules(option, val) elif option in ("vrules"): self._validate_vrules(option, val) elif option in ("fields"): self._validate_all_field_names(option, val) elif option in ("header", "border", "reversesort", "xhtml", "print_empty"): self._validate_true_or_false(option, val) elif option in ("header_style"): self._validate_header_style(val) elif option in ("int_format"): self._validate_int_format(option, val) elif option in ("float_format"): self._validate_float_format(option, val) elif option in ("vertical_char", "horizontal_char", "junction_char"): self._validate_single_char(option, val) elif option in ("attributes"): self._validate_attributes(option, val) else: raise Exception("Unrecognised option: %s!" 
% option) def _validate_field_names(self, val): # Check for appropriate length if self._field_names: try: assert len(val) == len(self._field_names) except AssertionError: raise Exception("Field name list has incorrect number of values, (actual) %d!=%d (expected)" % (len(val), len(self._field_names))) if self._rows: try: assert len(val) == len(self._rows[0]) except AssertionError: raise Exception("Field name list has incorrect number of values, (actual) %d!=%d (expected)" % (len(val), len(self._rows[0]))) # Check for uniqueness try: assert len(val) == len(set(val)) except AssertionError: raise Exception("Field names must be unique!") def _validate_header_style(self, val): try: assert val in ("cap", "title", "upper", "lower", None) except AssertionError: raise Exception("Invalid header style, use cap, title, upper, lower or None!") def _validate_align(self, val): try: assert val in ["l","c","r"] except AssertionError: raise Exception("Alignment %s is invalid, use l, c or r!" % val) def _validate_valign(self, val): try: assert val in ["t","m","b",None] except AssertionError: raise Exception("Alignment %s is invalid, use t, m, b or None!" % val) def _validate_nonnegative_int(self, name, val): try: assert int(val) >= 0 except AssertionError: raise Exception("Invalid value for %s: %s!" % (name, self._unicode(val))) def _validate_true_or_false(self, name, val): try: assert val in (True, False) except AssertionError: raise Exception("Invalid value for %s! Must be True or False." % name) def _validate_int_format(self, name, val): if val == "": return try: assert type(val) in (str, unicode) assert val.isdigit() except AssertionError: raise Exception("Invalid value for %s! Must be an integer format string." % name) def _validate_float_format(self, name, val): if val == "": return try: assert type(val) in (str, unicode) assert "." 
in val bits = val.split(".") assert len(bits) <= 2 assert bits[0] == "" or bits[0].isdigit() assert bits[1] == "" or bits[1].isdigit() except AssertionError: raise Exception("Invalid value for %s! Must be a float format string." % name) def _validate_function(self, name, val): try: assert hasattr(val, "__call__") except AssertionError: raise Exception("Invalid value for %s! Must be a function." % name) def _validate_hrules(self, name, val): try: assert val in (ALL, FRAME, HEADER, NONE) except AssertionError: raise Exception("Invalid value for %s! Must be ALL, FRAME, HEADER or NONE." % name) def _validate_vrules(self, name, val): try: assert val in (ALL, FRAME, NONE) except AssertionError: raise Exception("Invalid value for %s! Must be ALL, FRAME, or NONE." % name) def _validate_field_name(self, name, val): try: assert (val in self._field_names) or (val is None) except AssertionError: raise Exception("Invalid field name: %s!" % val) def _validate_all_field_names(self, name, val): try: for x in val: self._validate_field_name(name, x) except AssertionError: raise Exception("fields must be a sequence of field names!") def _validate_single_char(self, name, val): try: assert _str_block_width(val) == 1 except AssertionError: raise Exception("Invalid value for %s! Must be a string of length 1." 
% name) def _validate_attributes(self, name, val): try: assert isinstance(val, dict) except AssertionError: raise Exception("attributes must be a dictionary of name/value pairs!") ############################## # ATTRIBUTE MANAGEMENT # ############################## def _get_field_names(self): return self._field_names """The names of the fields Arguments: fields - list or tuple of field names""" def _set_field_names(self, val): val = [self._unicode(x) for x in val] self._validate_option("field_names", val) if self._field_names: old_names = self._field_names[:] self._field_names = val if self._align and old_names: for old_name, new_name in zip(old_names, val): self._align[new_name] = self._align[old_name] for old_name in old_names: if old_name not in self._align: self._align.pop(old_name) else: for field in self._field_names: self._align[field] = "c" if self._valign and old_names: for old_name, new_name in zip(old_names, val): self._valign[new_name] = self._valign[old_name] for old_name in old_names: if old_name not in self._valign: self._valign.pop(old_name) else: for field in self._field_names: self._valign[field] = "t" field_names = property(_get_field_names, _set_field_names) def _get_align(self): return self._align def _set_align(self, val): self._validate_align(val) for field in self._field_names: self._align[field] = val align = property(_get_align, _set_align) def _get_valign(self): return self._valign def _set_valign(self, val): self._validate_valign(val) for field in self._field_names: self._valign[field] = val valign = property(_get_valign, _set_valign) def _get_max_width(self): return self._max_width def _set_max_width(self, val): self._validate_option("max_width", val) for field in self._field_names: self._max_width[field] = val max_width = property(_get_max_width, _set_max_width) def _get_fields(self): """List or tuple of field names to include in displays Arguments: fields - list or tuple of field names to include in displays""" return self._fields 
def _set_fields(self, val): self._validate_option("fields", val) self._fields = val fields = property(_get_fields, _set_fields) def _get_start(self): """Start index of the range of rows to print Arguments: start - index of first data row to include in output""" return self._start def _set_start(self, val): self._validate_option("start", val) self._start = val start = property(_get_start, _set_start) def _get_end(self): """End index of the range of rows to print Arguments: end - index of last data row to include in output PLUS ONE (list slice style)""" return self._end def _set_end(self, val): self._validate_option("end", val) self._end = val end = property(_get_end, _set_end) def _get_sortby(self): """Name of field by which to sort rows Arguments: sortby - field name to sort by""" return self._sortby def _set_sortby(self, val): self._validate_option("sortby", val) self._sortby = val sortby = property(_get_sortby, _set_sortby) def _get_reversesort(self): """Controls direction of sorting (ascending vs descending) Arguments: reveresort - set to True to sort by descending order, or False to sort by ascending order""" return self._reversesort def _set_reversesort(self, val): self._validate_option("reversesort", val) self._reversesort = val reversesort = property(_get_reversesort, _set_reversesort) def _get_sort_key(self): """Sorting key function, applied to data points before sorting Arguments: sort_key - a function which takes one argument and returns something to be sorted""" return self._sort_key def _set_sort_key(self, val): self._validate_option("sort_key", val) self._sort_key = val sort_key = property(_get_sort_key, _set_sort_key) def _get_header(self): """Controls printing of table header with field names Arguments: header - print a header showing field names (True or False)""" return self._header def _set_header(self, val): self._validate_option("header", val) self._header = val header = property(_get_header, _set_header) def _get_header_style(self): """Controls 
stylisation applied to field names in header Arguments: header_style - stylisation to apply to field names in header ("cap", "title", "upper", "lower" or None)""" return self._header_style def _set_header_style(self, val): self._validate_header_style(val) self._header_style = val header_style = property(_get_header_style, _set_header_style) def _get_border(self): """Controls printing of border around table Arguments: border - print a border around the table (True or False)""" return self._border def _set_border(self, val): self._validate_option("border", val) self._border = val border = property(_get_border, _set_border) def _get_hrules(self): """Controls printing of horizontal rules after rows Arguments: hrules - horizontal rules style. Allowed values: FRAME, ALL, HEADER, NONE""" return self._hrules def _set_hrules(self, val): self._validate_option("hrules", val) self._hrules = val hrules = property(_get_hrules, _set_hrules) def _get_vrules(self): """Controls printing of vertical rules between columns Arguments: vrules - vertical rules style. 
Allowed values: FRAME, ALL, NONE""" return self._vrules def _set_vrules(self, val): self._validate_option("vrules", val) self._vrules = val vrules = property(_get_vrules, _set_vrules) def _get_int_format(self): """Controls formatting of integer data Arguments: int_format - integer format string""" return self._int_format def _set_int_format(self, val): # self._validate_option("int_format", val) for field in self._field_names: self._int_format[field] = val int_format = property(_get_int_format, _set_int_format) def _get_float_format(self): """Controls formatting of floating point data Arguments: float_format - floating point format string""" return self._float_format def _set_float_format(self, val): # self._validate_option("float_format", val) for field in self._field_names: self._float_format[field] = val float_format = property(_get_float_format, _set_float_format) def _get_padding_width(self): """The number of empty spaces between a column's edge and its content Arguments: padding_width - number of spaces, must be a positive integer""" return self._padding_width def _set_padding_width(self, val): self._validate_option("padding_width", val) self._padding_width = val padding_width = property(_get_padding_width, _set_padding_width) def _get_left_padding_width(self): """The number of empty spaces between a column's left edge and its content Arguments: left_padding - number of spaces, must be a positive integer""" return self._left_padding_width def _set_left_padding_width(self, val): self._validate_option("left_padding_width", val) self._left_padding_width = val left_padding_width = property(_get_left_padding_width, _set_left_padding_width) def _get_right_padding_width(self): """The number of empty spaces between a column's right edge and its content Arguments: right_padding - number of spaces, must be a positive integer""" return self._right_padding_width def _set_right_padding_width(self, val): self._validate_option("right_padding_width", val) 
        self._right_padding_width = val
    right_padding_width = property(_get_right_padding_width, _set_right_padding_width)

    def _get_vertical_char(self):
        """The character used when printing table borders to draw vertical lines

        Arguments:

        vertical_char - single character string used to draw vertical lines"""
        return self._vertical_char
    def _set_vertical_char(self, val):
        val = self._unicode(val)
        self._validate_option("vertical_char", val)
        self._vertical_char = val
    vertical_char = property(_get_vertical_char, _set_vertical_char)

    def _get_horizontal_char(self):
        """The character used when printing table borders to draw horizontal lines

        Arguments:

        horizontal_char - single character string used to draw horizontal lines"""
        return self._horizontal_char
    def _set_horizontal_char(self, val):
        val = self._unicode(val)
        self._validate_option("horizontal_char", val)
        self._horizontal_char = val
    horizontal_char = property(_get_horizontal_char, _set_horizontal_char)

    def _get_junction_char(self):
        """The character used when printing table borders to draw line junctions

        Arguments:

        junction_char - single character string used to draw line junctions"""
        return self._junction_char
    def _set_junction_char(self, val):
        val = self._unicode(val)
        # Validate under the option's own name so error messages refer to junction_char
        self._validate_option("junction_char", val)
        self._junction_char = val
    junction_char = property(_get_junction_char, _set_junction_char)

    def _get_format(self):
        """Controls whether or not HTML tables are formatted to match styling options

        Arguments:

        format - True or False"""
        return self._format
    def _set_format(self, val):
        self._validate_option("format", val)
        self._format = val
    format = property(_get_format, _set_format)

    def _get_print_empty(self):
        """Controls whether or not empty tables produce a header and frame or just an empty string

        Arguments:

        print_empty - True or False"""
        return self._print_empty
    def _set_print_empty(self, val):
        self._validate_option("print_empty", val)
        self._print_empty = val
    print_empty = property(_get_print_empty, _set_print_empty)

    def
_get_attributes(self): """A dictionary of HTML attribute name/value pairs to be included in the tag when printing HTML Arguments: attributes - dictionary of attributes""" return self._attributes def _set_attributes(self, val): self._validate_option("attributes", val) self._attributes = val attributes = property(_get_attributes, _set_attributes) ############################## # OPTION MIXER # ############################## def _get_options(self, kwargs): options = {} for option in self._options: if option in kwargs: self._validate_option(option, kwargs[option]) options[option] = kwargs[option] else: options[option] = getattr(self, "_"+option) return options ############################## # PRESET STYLE LOGIC # ############################## def set_style(self, style): if style == DEFAULT: self._set_default_style() elif style == MSWORD_FRIENDLY: self._set_msword_style() elif style == PLAIN_COLUMNS: self._set_columns_style() elif style == RANDOM: self._set_random_style() else: raise Exception("Invalid pre-set style!") def _set_default_style(self): self.header = True self.border = True self._hrules = FRAME self._vrules = ALL self.padding_width = 1 self.left_padding_width = 1 self.right_padding_width = 1 self.vertical_char = "|" self.horizontal_char = "-" self.junction_char = "+" def _set_msword_style(self): self.header = True self.border = True self._hrules = NONE self.padding_width = 1 self.left_padding_width = 1 self.right_padding_width = 1 self.vertical_char = "|" def _set_columns_style(self): self.header = True self.border = False self.padding_width = 1 self.left_padding_width = 0 self.right_padding_width = 8 def _set_random_style(self): # Just for fun! 
        self.header = random.choice((True, False))
        self.border = random.choice((True, False))
        self._hrules = random.choice((ALL, FRAME, HEADER, NONE))
        self._vrules = random.choice((ALL, FRAME, NONE))
        self.left_padding_width = random.randint(0,5)
        self.right_padding_width = random.randint(0,5)
        self.vertical_char = random.choice("~!@#$%^&*()_+|-=\\{}[];':\",./;<>?")
        self.horizontal_char = random.choice("~!@#$%^&*()_+|-=\\{}[];':\",./;<>?")
        self.junction_char = random.choice("~!@#$%^&*()_+|-=\\{}[];':\",./;<>?")

    ##############################
    # DATA INPUT METHODS         #
    ##############################

    def add_row(self, row):

        """Add a row to the table

        Arguments:

        row - row of data, should be a list with as many elements as the table
        has fields"""

        if self._field_names and len(row) != len(self._field_names):
            raise Exception("Row has incorrect number of values, (actual) %d!=%d (expected)" %(len(row),len(self._field_names)))
        if not self._field_names:
            self.field_names = [("Field %d" % (n+1)) for n in range(0,len(row))]
        self._rows.append(list(row))

    def del_row(self, row_index):

        """Delete a row from the table

        Arguments:

        row_index - The index of the row you want to delete. Indexing starts at 0."""

        if row_index > len(self._rows)-1:
            raise Exception("Can't delete row at index %d, table only has %d rows!" % (row_index, len(self._rows)))
        del self._rows[row_index]

    def add_column(self, fieldname, column, align="c", valign="t"):

        """Add a column to the table.
Arguments: fieldname - name of the field to contain the new column of data column - column of data, should be a list with as many elements as the table has rows align - desired alignment for this column - "l" for left, "c" for centre and "r" for right valign - desired vertical alignment for new columns - "t" for top, "m" for middle and "b" for bottom""" if len(self._rows) in (0, len(column)): self._validate_align(align) self._validate_valign(valign) self._field_names.append(fieldname) self._align[fieldname] = align self._valign[fieldname] = valign for i in range(0, len(column)): if len(self._rows) < i+1: self._rows.append([]) self._rows[i].append(column[i]) else: raise Exception("Column length %d does not match number of rows %d!" % (len(column), len(self._rows))) def clear_rows(self): """Delete all rows from the table but keep the current field names""" self._rows = [] def clear(self): """Delete all rows and field names from the table, maintaining nothing but styling options""" self._rows = [] self._field_names = [] self._widths = [] ############################## # MISC PUBLIC METHODS # ############################## def copy(self): return copy.deepcopy(self) ############################## # MISC PRIVATE METHODS # ############################## def _format_value(self, field, value): if isinstance(value, int) and field in self._int_format: value = self._unicode(("%%%sd" % self._int_format[field]) % value) elif isinstance(value, float) and field in self._float_format: value = self._unicode(("%%%sf" % self._float_format[field]) % value) return self._unicode(value) def _compute_widths(self, rows, options): if options["header"]: widths = [_get_size(field)[0] for field in self._field_names] else: widths = len(self.field_names) * [0] for row in rows: for index, value in enumerate(row): fieldname = self.field_names[index] if fieldname in self.max_width: widths[index] = max(widths[index], min(_get_size(value)[0], self.max_width[fieldname])) else: widths[index] = 
max(widths[index], _get_size(value)[0]) self._widths = widths def _get_padding_widths(self, options): if options["left_padding_width"] is not None: lpad = options["left_padding_width"] else: lpad = options["padding_width"] if options["right_padding_width"] is not None: rpad = options["right_padding_width"] else: rpad = options["padding_width"] return lpad, rpad def _get_rows(self, options): """Return only those data rows that should be printed, based on slicing and sorting. Arguments: options - dictionary of option settings.""" # Make a copy of only those rows in the slice range rows = copy.deepcopy(self._rows[options["start"]:options["end"]]) # Sort if necessary if options["sortby"]: sortindex = self._field_names.index(options["sortby"]) # Decorate rows = [[row[sortindex]]+row for row in rows] # Sort rows.sort(reverse=options["reversesort"], key=options["sort_key"]) # Undecorate rows = [row[1:] for row in rows] return rows def _format_row(self, row, options): return [self._format_value(field, value) for (field, value) in zip(self._field_names, row)] def _format_rows(self, rows, options): return [self._format_row(row, options) for row in rows] ############################## # PLAIN TEXT STRING METHODS # ############################## def get_string(self, **kwargs): """Return string representation of table in current state. Arguments: start - index of first data row to include in output end - index of last data row to include in output PLUS ONE (list slice style) fields - names of fields (columns) to include header - print a header showing field names (True or False) border - print a border around the table (True or False) hrules - controls printing of horizontal rules after rows. Allowed values: ALL, FRAME, HEADER, NONE vrules - controls printing of vertical rules between columns. 
Allowed values: FRAME, ALL, NONE int_format - controls formatting of integer data float_format - controls formatting of floating point data padding_width - number of spaces on either side of column data (only used if left and right paddings are None) left_padding_width - number of spaces on left hand side of column data right_padding_width - number of spaces on right hand side of column data vertical_char - single character string used to draw vertical lines horizontal_char - single character string used to draw horizontal lines junction_char - single character string used to draw line junctions sortby - name of field to sort rows by sort_key - sorting key function, applied to data points before sorting reversesort - True or False to sort in descending or ascending order print empty - if True, stringify just the header for an empty table, if False return an empty string """ options = self._get_options(kwargs) lines = [] # Don't think too hard about an empty table # Is this the desired behaviour? Maybe we should still print the header? if self.rowcount == 0 and (not options["print_empty"] or not options["border"]): return "" # Get the rows we need to print, taking into account slicing, sorting, etc. 
rows = self._get_rows(options) # Turn all data in all rows into Unicode, formatted as desired formatted_rows = self._format_rows(rows, options) # Compute column widths self._compute_widths(formatted_rows, options) # Add header or top of border self._hrule = self._stringify_hrule(options) if options["header"]: lines.append(self._stringify_header(options)) elif options["border"] and options["hrules"] in (ALL, FRAME): lines.append(self._hrule) # Add rows for row in formatted_rows: lines.append(self._stringify_row(row, options)) # Add bottom of border if options["border"] and options["hrules"] == FRAME: lines.append(self._hrule) return self._unicode("\n").join(lines) def _stringify_hrule(self, options): if not options["border"]: return "" lpad, rpad = self._get_padding_widths(options) if options['vrules'] in (ALL, FRAME): bits = [options["junction_char"]] else: bits = [options["horizontal_char"]] # For tables with no data or fieldnames if not self._field_names: bits.append(options["junction_char"]) return "".join(bits) for field, width in zip(self._field_names, self._widths): if options["fields"] and field not in options["fields"]: continue bits.append((width+lpad+rpad)*options["horizontal_char"]) if options['vrules'] == ALL: bits.append(options["junction_char"]) else: bits.append(options["horizontal_char"]) if options["vrules"] == FRAME: bits.pop() bits.append(options["junction_char"]) return "".join(bits) def _stringify_header(self, options): bits = [] lpad, rpad = self._get_padding_widths(options) if options["border"]: if options["hrules"] in (ALL, FRAME): bits.append(self._hrule) bits.append("\n") if options["vrules"] in (ALL, FRAME): bits.append(options["vertical_char"]) else: bits.append(" ") # For tables with no data or field names if not self._field_names: if options["vrules"] in (ALL, FRAME): bits.append(options["vertical_char"]) else: bits.append(" ") for field, width, in zip(self._field_names, self._widths): if options["fields"] and field not in 
options["fields"]: continue if self._header_style == "cap": fieldname = field.capitalize() elif self._header_style == "title": fieldname = field.title() elif self._header_style == "upper": fieldname = field.upper() elif self._header_style == "lower": fieldname = field.lower() else: fieldname = field bits.append(" " * lpad + self._justify(fieldname, width, self._align[field]) + " " * rpad) if options["border"]: if options["vrules"] == ALL: bits.append(options["vertical_char"]) else: bits.append(" ") # If vrules is FRAME, then we just appended a space at the end # of the last field, when we really want a vertical character if options["border"] and options["vrules"] == FRAME: bits.pop() bits.append(options["vertical_char"]) if options["border"] and options["hrules"] != NONE: bits.append("\n") bits.append(self._hrule) return "".join(bits) def _stringify_row(self, row, options): for index, field, value, width, in zip(range(0,len(row)), self._field_names, row, self._widths): # Enforce max widths lines = value.split("\n") new_lines = [] for line in lines: if _str_block_width(line) > width: line = textwrap.fill(line, width) new_lines.append(line) lines = new_lines value = "\n".join(lines) row[index] = value row_height = 0 for c in row: h = _get_size(c)[1] if h > row_height: row_height = h bits = [] lpad, rpad = self._get_padding_widths(options) for y in range(0, row_height): bits.append([]) if options["border"]: if options["vrules"] in (ALL, FRAME): bits[y].append(self.vertical_char) else: bits[y].append(" ") for field, value, width, in zip(self._field_names, row, self._widths): valign = self._valign[field] lines = value.split("\n") dHeight = row_height - len(lines) if dHeight: if valign == "m": lines = [""] * int(dHeight / 2) + lines + [""] * (dHeight - int(dHeight / 2)) elif valign == "b": lines = [""] * dHeight + lines else: lines = lines + [""] * dHeight y = 0 for l in lines: if options["fields"] and field not in options["fields"]: continue bits[y].append(" " * lpad + 
self._justify(l, width, self._align[field]) + " " * rpad) if options["border"]: if options["vrules"] == ALL: bits[y].append(self.vertical_char) else: bits[y].append(" ") y += 1 # If vrules is FRAME, then we just appended a space at the end # of the last field, when we really want a vertical character for y in range(0, row_height): if options["border"] and options["vrules"] == FRAME: bits[y].pop() bits[y].append(options["vertical_char"]) if options["border"] and options["hrules"]== ALL: bits[row_height-1].append("\n") bits[row_height-1].append(self._hrule) for y in range(0, row_height): bits[y] = "".join(bits[y]) return "\n".join(bits) ############################## # HTML STRING METHODS # ############################## def get_html_string(self, **kwargs): """Return string representation of HTML formatted version of table in current state. Arguments: start - index of first data row to include in output end - index of last data row to include in output PLUS ONE (list slice style) fields - names of fields (columns) to include header - print a header showing field names (True or False) border - print a border around the table (True or False) hrules - controls printing of horizontal rules after rows. Allowed values: ALL, FRAME, HEADER, NONE vrules - controls printing of vertical rules between columns. Allowed values: FRAME, ALL, NONE int_format - controls formatting of integer data float_format - controls formatting of floating point data padding_width - number of spaces on either side of column data (only used if left and right paddings are None) left_padding_width - number of spaces on left hand side of column data right_padding_width - number of spaces on right hand side of column data sortby - name of field to sort rows by sort_key - sorting key function, applied to data points before sorting attributes - dictionary of name/value pairs to include as HTML attributes in the
<table> tag
        xhtml - print <br/> tags if True, <br> tags if false"""

        options = self._get_options(kwargs)

        if options["format"]:
            string = self._get_formatted_html_string(options)
        else:
            string = self._get_simple_html_string(options)

        return string

    def _get_simple_html_string(self, options):

        lines = []
        if options["xhtml"]:
            linebreak = "<br/>"
        else:
            linebreak = "<br>"

        open_tag = []
        open_tag.append("<table")
        if options["attributes"]:
            for attr_name in options["attributes"]:
                open_tag.append(" %s=\"%s\"" % (attr_name, options["attributes"][attr_name]))
        open_tag.append(">")
        lines.append("".join(open_tag))

        # Headers
        if options["header"]:
            lines.append("    <tr>")
            for field in self._field_names:
                if options["fields"] and field not in options["fields"]:
                    continue
                lines.append("        <th>%s</th>" % escape(field).replace("\n", linebreak))
            lines.append("    </tr>")

        # Data
        rows = self._get_rows(options)
        formatted_rows = self._format_rows(rows, options)
        for row in formatted_rows:
            lines.append("    <tr>")
            for field, datum in zip(self._field_names, row):
                if options["fields"] and field not in options["fields"]:
                    continue
                lines.append("        <td>%s</td>" % escape(datum).replace("\n", linebreak))
            lines.append("    </tr>")
        lines.append("</table>")

        return self._unicode("\n").join(lines)

    def _get_formatted_html_string(self, options):

        lines = []
        lpad, rpad = self._get_padding_widths(options)
        if options["xhtml"]:
            linebreak = "<br/>"
        else:
            linebreak = "<br>
" open_tag = [] open_tag.append("") lines.append("".join(open_tag)) # Headers if options["header"]: lines.append(" ") for field in self._field_names: if options["fields"] and field not in options["fields"]: continue lines.append(" %s" % (lpad, rpad, escape(field).replace("\n", linebreak))) lines.append(" ") # Data rows = self._get_rows(options) formatted_rows = self._format_rows(rows, options) aligns = [] valigns = [] for field in self._field_names: aligns.append({ "l" : "left", "r" : "right", "c" : "center" }[self._align[field]]) valigns.append({"t" : "top", "m" : "middle", "b" : "bottom"}[self._valign[field]]) for row in formatted_rows: lines.append(" ") for field, datum, align, valign in zip(self._field_names, row, aligns, valigns): if options["fields"] and field not in options["fields"]: continue lines.append(" %s" % (lpad, rpad, align, valign, escape(datum).replace("\n", linebreak))) lines.append(" ") lines.append("") return self._unicode("\n").join(lines) ############################## # UNICODE WIDTH FUNCTIONS # ############################## def _char_block_width(char): # Basic Latin, which is probably the most common case #if char in xrange(0x0021, 0x007e): #if char >= 0x0021 and char <= 0x007e: if 0x0021 <= char <= 0x007e: return 1 # Chinese, Japanese, Korean (common) if 0x4e00 <= char <= 0x9fff: return 2 # Hangul if 0xac00 <= char <= 0xd7af: return 2 # Combining? 
if unicodedata.combining(uni_chr(char)): return 0 # Hiragana and Katakana if 0x3040 <= char <= 0x309f or 0x30a0 <= char <= 0x30ff: return 2 # Full-width Latin characters if 0xff01 <= char <= 0xff60: return 2 # CJK punctuation if 0x3000 <= char <= 0x303e: return 2 # Backspace and delete if char in (0x0008, 0x007f): return -1 # Other control characters elif char in (0x0000, 0x001f): return 0 # Take a guess return 1 def _str_block_width(val): return sum(itermap(_char_block_width, itermap(ord, _re.sub("", val)))) ############################## # TABLE FACTORIES # ############################## def from_csv(fp, field_names = None, **kwargs): dialect = csv.Sniffer().sniff(fp.read(1024)) fp.seek(0) reader = csv.reader(fp, dialect) table = PrettyTable(**kwargs) if field_names: table.field_names = field_names else: if py3k: table.field_names = [x.strip() for x in next(reader)] else: table.field_names = [x.strip() for x in reader.next()] for row in reader: table.add_row([x.strip() for x in row]) return table def from_db_cursor(cursor, **kwargs): if cursor.description: table = PrettyTable(**kwargs) table.field_names = [col[0] for col in cursor.description] for row in cursor.fetchall(): table.add_row(row) return table class TableHandler(HTMLParser): def __init__(self, **kwargs): HTMLParser.__init__(self) self.kwargs = kwargs self.tables = [] self.last_row = [] self.rows = [] self.max_row_width = 0 self.active = None self.last_content = "" self.is_last_row_header = False def handle_starttag(self,tag, attrs): self.active = tag if tag == "th": self.is_last_row_header = True def handle_endtag(self,tag): if tag in ["th", "td"]: stripped_content = self.last_content.strip() self.last_row.append(stripped_content) if tag == "tr": self.rows.append( (self.last_row, self.is_last_row_header)) self.max_row_width = max(self.max_row_width, len(self.last_row)) self.last_row = [] self.is_last_row_header = False if tag == "table": table = self.generate_table(self.rows) self.tables.append(table) 
            self.rows = []
        self.last_content = " "
        self.active = None

    def handle_data(self, data):
        self.last_content += data

    def generate_table(self, rows):
        """
        Generates a PrettyTable object from a list of rows.
        """
        table = PrettyTable(**self.kwargs)
        for row in rows:
            if len(row[0]) < self.max_row_width:
                # Pad short rows with "-" so every row has max_row_width cells
                appends = self.max_row_width - len(row[0])
                for i in range(appends):
                    row[0].append("-")

            if row[1]:
                # Header row: field names must be unique
                self.make_fields_unique(row[0])
                table.field_names = row[0]
            else:
                table.add_row(row[0])
        return table

    def make_fields_unique(self, fields):
        """
        Iterates over the row and makes each field name unique
        """
        for i in range(0, len(fields)):
            for j in range(i+1, len(fields)):
                if fields[i] == fields[j]:
                    fields[j] += "'"

def from_html(html_code, **kwargs):
    """
    Generates a list of PrettyTables from a string of HTML code. Each <table> in
    the HTML becomes one PrettyTable object.
    """

    parser = TableHandler(**kwargs)
    parser.feed(html_code)
    return parser.tables

def from_html_one(html_code, **kwargs):
    """
    Generates a PrettyTable from a string of HTML code which contains only a
    single <table>
    """

    tables = from_html(html_code, **kwargs)
    try:
        assert len(tables) == 1
    except AssertionError:
        raise Exception("More than one <table> in provided HTML code! Use from_html instead.")
    return tables[0]

##############################
# MAIN (TEST FUNCTION)       #
##############################

def main():

    x = PrettyTable(["City name", "Area", "Population", "Annual Rainfall"])
    x.sortby = "Population"
    x.reversesort = True
    x.int_format["Area"] = "04d"
    x.float_format = "6.1f"
    x.align["City name"] = "l" # Left align city names
    x.add_row(["Adelaide", 1295, 1158259, 600.5])
    x.add_row(["Brisbane", 5905, 1857594, 1146.4])
    x.add_row(["Darwin", 112, 120900, 1714.7])
    x.add_row(["Hobart", 1357, 205556, 619.5])
    x.add_row(["Sydney", 2058, 4336374, 1214.8])
    x.add_row(["Melbourne", 1566, 3806092, 646.9])
    x.add_row(["Perth", 5386, 1554769, 869.4])
    print(x)

if __name__ == "__main__":
    main()
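
# The sorting done in _get_rows follows the decorate-sort-undecorate pattern:
# the sortby value is prepended to each row, the rows are sorted, and the
# prepended value is dropped again. Below is a minimal standalone sketch of
# that idea. The helper name sort_rows and the shortened city data are
# assumptions made for illustration only; the sketch does not import or
# modify PrettyTable.

```python
# Hypothetical helper (not part of prettytable): mirrors the
# decorate-sort-undecorate steps used for sortby/reversesort/sort_key.
def sort_rows(rows, field_names, sortby, reverse=False, key=None):
    sortindex = field_names.index(sortby)
    # Decorate: prepend the sort value so it is compared first
    decorated = [[row[sortindex]] + list(row) for row in rows]
    # Sort: a key function, if given, receives the whole decorated row,
    # so element 0 of its argument is the sortby value
    decorated.sort(reverse=reverse, key=key)
    # Undecorate: drop the prepended sort value again
    return [row[1:] for row in decorated]


fields = ["City name", "Area", "Population", "Annual Rainfall"]
cities = [
    ["Adelaide", 1295, 1158259, 600.5],
    ["Darwin", 112, 120900, 1714.7],
    ["Sydney", 2058, 4336374, 1214.8],
]

# Largest population first, as in main() with reversesort = True
by_population = sort_rows(cities, fields, "Population", reverse=True)

# With a key function: sort on the decorated value only
by_area = sort_rows(cities, fields, "Area", key=lambda row: row[0])
```

# Because the rows are copied while being decorated, the input list keeps its
# original order; only the returned list is sorted.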