Repository: giantbranch/mipsAudit
Branch: master
Commit: 8650a97bb410
Files: 5
Total size: 116.1 KB
Directory structure:
gitextract_w6dphd8d/
├── .github/
│ └── FUNDING.yml
├── README.md
├── mipsAudit.py
└── oldversion(py2)/
├── mipsAudit.py
└── prettytable.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: #
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: ['http://pic.giantbranch.cn/pic/1551450728861.jpg']
================================================
FILE: README.md
================================================
# IDAPython mipsAudit
## 简介
这是一个简单的IDAPython脚本。
进一步来说是MIPS静态汇编审计辅助脚本。
可能会有bug,欢迎大家完善。
> **v3.0 更新**: 支持 IDA 7.x,8.x API 和 Python 3,新增多项高级漏洞检测功能。
## 功能
### 基础功能
1. 找到危险函数的调用处,并且高亮该行(也可以下断点:在源码 `audit()` 函数中取消 `idc.add_bpt(call_addr)` 一行的注释即可)
2. 给参数赋值处加上注释
3. 最后以表格的形式输出函数名,调用地址,参数,还有当前函数的缓冲区大小
**大家双击addr那一列的地址,即可跳到对应的地址处**


### v3.0 新增功能
#### 高级漏洞检测
| 检测类型 | 说明 |
|---------|------|
| **命令注入检测** | 追踪 system/popen/execve 参数是否来自外部输入 |
| **栈溢出检测** | 比较目标缓冲区大小与源数据长度 |
| **格式化字符串漏洞** | 检测 %n 写入原语和用户可控格式字符串 |
| **整数溢出检测** | 检测 malloc/calloc 的 size 参数来源和算术运算 |
| **双重释放检测** | 追踪 free() 调用,检测同一指针多次释放 |
| **数据流分析** | 追踪 read/recv 返回数据流向危险函数 |
| **Wrapper函数识别** | 识别封装了危险函数的自定义函数 |
#### 风险等级高亮
| 颜色 | 等级 | 说明 |
|------|------|------|
| 红色 | HIGH | 需要立即关注 |
| 橙色 | MEDIUM | 需要人工复核 |
| 绿色 | LOW | 信息提示 |
#### 结果导出
自动生成带时间戳的 HTML 报告,保存到 IDB 文件所在目录:
```
mipsAudit_results_20260129_143052.html
```
#### 配置文件支持
支持通过 `mipsAudit_config.json`(放在 IDB 文件所在目录)扩展默认函数列表,配置中的函数名会追加到内置列表中:
```json
{
"dangerous_functions": ["custom_strcpy"],
"command_execution_function": ["custom_exec"],
"_comment": "扩展默认函数列表"
}
```
## 审计的危险函数
```python
dangerous_functions = [
"strcpy", "strcat", "sprintf", "read", "getenv",
"gets", "scanf", "vscanf", "realpath", "access", "stat", "lstat"
]
attention_function = [
"memcpy", "strncpy", "sscanf", "strncat", "snprintf",
"vprintf", "printf", "fprintf", "vfprintf", "vsprintf",
"vsnprintf", "syslog", "memmove", "bcopy"
]
command_execution_function = [
"system", "execve", "popen", "unlink",
"execl", "execle", "execlp", "execv", "execvp",
"dlopen", "mmap", "mprotect"
]
memory_alloc_functions = [
"malloc", "calloc", "realloc", "memalign",
"valloc", "pvalloc", "aligned_alloc", "mmap"
]
memory_free_functions = [
"free", "cfree", "munmap"
]
```
## 运行流程
```
PHASE 1: Basic Function Audit # 基础危险函数审计
PHASE 2: Enhanced Vulnerability Detection # 增强漏洞检测
PHASE 3: Advanced Analysis # 高级分析(数据流、Wrapper识别)
PHASE 4: Results Summary & Export # 结果汇总与导出
```
## 使用
### 环境要求
- IDA Pro 7.0+
- Python 3.x(IDA 内置)
- Python 模块 `prettytable`(脚本通过 `from prettytable import PrettyTable` 导入,用于输出表格)
### 运行方式
File - Script file

选择mipsAudit.py

即可看到效果

双击地址即可跳到对应的代码处

## 更新日志
### v3.0 (2026-01)
- 支持 IDA 7.x+ API
- 迁移至 Python 3
- 新增格式化字符串漏洞检测(%n 检测)
- 新增命令注入参数来源追踪
- 新增栈溢出缓冲区大小比较
- 新增整数溢出检测(malloc/calloc size 参数)
- 新增双重释放检测
- 新增数据流分析(read/recv 返回值追踪)
- 新增 Wrapper 函数识别
- 新增基本块分析(修复跨基本块误判)
- 新增 HTML 报告导出(带时间戳)
- 新增外部 JSON 配置文件支持
- 新增扫描进度显示
- 扩展危险函数列表
### v1.0 (2018-05)
- 初始版本 by giantbranch
================================================
FILE: mipsAudit.py
================================================
# -*- coding: utf-8 -*-
# reference
# 《ida pro 权威指南》
# 《python 灰帽子》
# 《家用路由器0day漏洞挖掘》
# https://github.com/wangzery/SearchOverflow/blob/master/SearchOverflow.py
# Updated for IDA 7.x+ API and Python 3
# Enhanced with advanced vulnerability detection v3.0
import idc
import idaapi
import idautils
from prettytable import PrettyTable
from collections import defaultdict
import json
import csv
import os
import re
from datetime import datetime
# IDA 7.x+ uses BADADDR from idaapi; alias it so the rest of the script can
# compare against a bare BADADDR ("no such address").
BADADDR = idaapi.BADADDR
# Debug switch; no live code in the visible part of this file reads it.
DEBUG = True
# ============================================================
# Configuration - Can be overridden by external config file
# ============================================================
# File name of the optional JSON config. load_config() and
# save_default_config() join it with the directory from get_output_dir().
CONFIG_FILE = "mipsAudit_config.json"
# Default function lists (can be extended via config file)
# All lists below hold plain symbol names; load_config() appends entries from
# the JSON config to them in place.

# Functions treated as dangerous. auditEnhanced() runs checkStackOverflow()
# for names in this list or in attention_function.
dangerous_functions = [
    "strcpy",
    "strcat",
    "sprintf",
    "read",
    "getenv",
    "gets", # No boundary check - extremely dangerous
    "scanf", # Format input vulnerability
    "vscanf",
    "realpath", # Path traversal
    "access", # TOCTOU race condition
    "stat", # TOCTOU race condition
    "lstat",
]
# Functions that are safer variants but still deserve a look.
attention_function = [
    "memcpy",
    "strncpy",
    "sscanf",
    "strncat",
    "snprintf",
    "vprintf",
    "printf",
    "fprintf",
    "vfprintf",
    "vsprintf",
    "vsnprintf",
    "syslog",
    "memmove",
    "bcopy",
]
# Functions for which auditEnhanced() runs checkCommandInjection();
# traceDataFlowForward() rates tainted data reaching one of these as HIGH.
command_execution_function = [
    "system",
    "execve",
    "popen",
    "unlink",
    "execl",
    "execle",
    "execlp",
    "execv",
    "execvp",
    "dlopen",
    "mmap", # Memory mapping
    "mprotect", # Change memory protection
]
# External input source functions (for taint tracking)
# Consulted by traceArgSource() and enumerated by collectTaintSources().
external_input_functions = [
    "getenv",
    "read",
    "recv",
    "recvfrom",
    "recvmsg",
    "fgets",
    "fread",
    "fgetc",
    "gets",
    "scanf",
    "fscanf",
    "getchar",
    "getc",
    "fgetws",
    "getwchar",
    "getline",
    "getdelim",
    "socket",
    "accept",
    "gethostbyname",
]
# Memory management functions
# auditEnhanced() runs checkIntegerOverflow() for names in this list.
memory_alloc_functions = [
    "malloc",
    "calloc",
    "realloc",
    "memalign",
    "valloc",
    "pvalloc",
    "aligned_alloc",
    "mmap",
]
# Deallocation functions. NOTE(review): the consumer of this list is not in
# the visible part of the file; presumably each name is fed to auditFreeCall().
memory_free_functions = [
    "free",
    "cfree",
    "munmap",
]
# Format string functions (for %n detection)
# auditEnhanced() runs checkFormatStringVuln() for names in this list.
format_string_functions = [
    "printf",
    "fprintf",
    "sprintf",
    "snprintf",
    "vprintf",
    "vfprintf",
    "vsprintf",
    "vsnprintf",
    "syslog",
]
# describe arg num of function
# audit() looks a function name up in these tables (in this order) to decide
# how many argument columns to report; a name found in none of them, and not
# in format_function_offset_dict either, is skipped with a hint message.
one_arg_function = [
    "getenv",
    "system",
    "unlink",
    "free",
    "cfree",
    "malloc",
    "gets",
]
two_arg_function = [
    "strcpy",
    "strcat",
    "popen",
    "calloc",
    "dlopen",
    "fgets",
    "access",
    "stat",
    "lstat",
    "realpath",
    "mprotect",
]
three_arg_function = [
    "strncpy",
    "strncat",
    "memcpy",
    "memmove",
    "bcopy",
    "execve",
    "read",
    "recv",
    "fread",
    "realloc",
]
# Maps a format function to the zero-based index of its format-string
# argument (used as the $aN register number). audit() reports value + 1 fixed
# arguments for these functions.
format_function_offset_dict = {
    "sprintf": 1,
    "sscanf": 1,
    "snprintf": 2,
    "vprintf": 0,
    "printf": 0,
    "fprintf": 1,
    "vfprintf": 1,
    "vsprintf": 1,
    "vsnprintf": 2,
    "syslog": 1,
    "scanf": 0,
}
# Risk level colors
# Passed to idc.set_color() to highlight call sites. Judging by the color
# names next to each value the encoding is 0xBBGGRR.
RISK_HIGH = 0x0000ff # Red
RISK_MEDIUM = 0x00a5ff # Orange
RISK_LOW = 0x00ff00 # Green
RISK_INFO = 0xffff00 # Cyan
# ============================================================
# Enhanced Analysis - Data Structures
# ============================================================
# Store function call information for cross-reference analysis
# (module-level state shared by the analysis passes below)
taint_sources = {} # addr -> function_name (external input sources)
free_calls = defaultdict(list) # func_addr -> [(call_addr, arg_info), ...]
audit_results = [] # Store all findings for export
wrapper_functions = {} # Detected wrapper functions
data_flow_graph = defaultdict(list) # addr -> [(dest_addr, dest_func), ...]
# Progress tracking
# NOTE(review): neither counter is updated in the visible part of the file.
total_functions = 0
processed_functions = 0
# ============================================================
# Configuration File Support
# ============================================================
def get_output_dir():
    """Return the directory used for the config file and generated output.

    The directory of the current IDB (``idc.get_idb_path()``) is preferred.
    When that path is empty or the lookup fails, the current working
    directory is returned instead.
    """
    try:
        idb_path = idc.get_idb_path()
        if idb_path:
            return os.path.dirname(idb_path)
    except Exception:
        # Best effort: any lookup failure falls back to the working directory.
        # Unlike a bare ``except`` this does not swallow KeyboardInterrupt or
        # SystemExit.
        pass
    return os.getcwd()
def load_config():
    """Extend the built-in function lists from the external JSON config.

    The file ``CONFIG_FILE`` is looked up in ``get_output_dir()``. Every
    recognised key must map to a list of function names; those names are
    appended in place to the module-level list of the same name. Names already
    present are skipped so a function is not audited twice, and values that
    are not lists (or entries that are not strings) are ignored with a warning
    instead of being spliced in character by character.

    Returns:
        True when a config file was found and parsed as a JSON object,
        False when it does not exist or could not be loaded.
    """
    # Config key -> module-level list that the key extends.
    targets = {
        'dangerous_functions': dangerous_functions,
        'attention_function': attention_function,
        'command_execution_function': command_execution_function,
        'external_input_functions': external_input_functions,
        'memory_alloc_functions': memory_alloc_functions,
        'memory_free_functions': memory_free_functions,
        'format_string_functions': format_string_functions,
    }
    config_path = os.path.join(get_output_dir(), CONFIG_FILE)
    if not os.path.exists(config_path):
        return False
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except Exception as e:
        print("[!] Error loading config: %s" % str(e))
        return False
    if not isinstance(config, dict):
        print("[!] Error loading config: %s" % "top-level JSON value must be an object")
        return False
    for key, target in targets.items():
        if key not in config:
            continue
        extra = config[key]
        if not isinstance(extra, list):
            print("[!] Ignoring config key %s: expected a list of names" % key)
            continue
        for name in extra:
            if not isinstance(name, str):
                print("[!] Ignoring non-string entry in %s: %r" % (key, name))
            elif name not in target:
                target.append(name)
    print("[*] Loaded configuration from %s" % config_path)
    return True
def save_default_config():
    """Write a template config file with an empty list for every known key.

    The file is created as ``CONFIG_FILE`` inside ``get_output_dir()``. It
    does not contain the built-in defaults: each list is empty and is meant to
    be filled with extra function names. Failures are reported on the output
    window and otherwise ignored.
    """
    list_keys = (
        "dangerous_functions",
        "attention_function",
        "command_execution_function",
        "external_input_functions",
        "memory_alloc_functions",
        "memory_free_functions",
        "format_string_functions",
    )
    template = {key: [] for key in list_keys}
    template["_comment"] = "Add custom functions to extend the default lists"
    target_path = os.path.join(get_output_dir(), CONFIG_FILE)
    try:
        with open(target_path, 'w', encoding='utf-8') as handle:
            handle.write(json.dumps(template, indent=4))
        print("[*] Created default config file: %s" % target_path)
    except Exception as err:
        print("[!] Error saving config: %s" % str(err))
# ============================================================
# Progress Display
# ============================================================
def show_progress(current, total, prefix="Progress"):
    """Print a 30-character text progress bar.

    Nothing is printed when ``total`` is 0. The line starts with a carriage
    return and is left unterminated until ``current`` equals ``total``, at
    which point a newline finishes it.
    """
    if total == 0:
        return
    width = 30
    done = (current * width) // total
    gauge = '=' * done + '-' * (width - done)
    line = "\r%s: [%s] %d%% (%d/%d)" % (prefix, gauge, (current * 100) // total, current, total)
    # Terminate the line only once the work is complete.
    print(line, end='\n' if current == total else '')
# ============================================================
# Helper Functions
# ============================================================
def printFunc(func_name):
    """Return a three-line banner announcing the audit of ``func_name``.

    The middle line is padded with '=' up to the width of the 40-character
    rule lines; longer titles are left unpadded.
    """
    rule = "=" * 40
    title = "".join(("========== Auditing ", func_name, " "))
    return "\n".join((rule, title.ljust(len(rule), "="), rule))
def getFuncAddr(func_name):
    """Resolve ``func_name`` to its address and print the audit banner.

    Returns the address, or False when the name is not known to the database
    (no banner is printed in that case).
    """
    ea = idc.get_name_ea_simple(func_name)
    if ea == BADADDR:
        return False
    print(printFunc(func_name))
    return ea
def getFormatString(addr):
    """Fetch the string literal referenced by the instruction at ``addr``.

    Operand 1, then operand 2, is checked for operand type 5 (listed as
    o_imm, "Immediate Value", in the idc.get_operand_type return values). The
    operand text is reduced to a bare symbol name, resolved to an address and
    read as a string literal.

    Returns:
        ``[string_addr, string_content]`` on success, otherwise the string
        ``"get fail"``.
    """
    immediate_type = 5  # o_imm
    for op_num in (1, 2):
        if idc.get_operand_type(addr, op_num) == immediate_type:
            break
    else:
        return "get fail"
    # Keep only the symbol: drop anything after a space, '+' or '-', and '('.
    operand_text = idc.print_operand(addr, op_num)
    symbol = operand_text.split(" ")[0].split("+")[0].split("-")[0].replace("(", "")
    string_addr = idc.get_name_ea_simple(symbol)
    if string_addr == BADADDR:
        return "get fail"
    content = idc.get_strlit_contents(string_addr)
    if content is None:
        return "get fail"
    if isinstance(content, bytes):
        content = content.decode('utf-8', errors='replace')
    return [string_addr, content]
def getArgAddr(start_addr, regNum):
    """Locate the instruction that sets argument register ``$a<regNum>``.

    The code references *from* ``start_addr`` are checked first. Otherwise the
    search walks backwards, starting at the first code reference *to*
    ``start_addr`` (or at the previous instruction when there is none), and
    returns the first instruction whose operand 0 is the register and whose
    mnemonic is neither a conditional branch (bn*/be*/bg*/bl*) nor a jump.
    The backward walk gives up after 50 steps.

    Returns the instruction address, or BADADDR when nothing was found.
    """
    branch_prefixes = ("bn", "be", "bg", "bl")
    max_steps = 50
    wanted = "$a" + str(regNum)

    for ref in idautils.CodeRefsFrom(start_addr, 0):
        if ref != BADADDR and idc.print_operand(ref, 0) == wanted:
            return ref

    cursor = next(iter(idautils.CodeRefsTo(start_addr, 0)), start_addr)
    if cursor == start_addr:
        cursor = idc.prev_head(start_addr)

    steps = 0
    while cursor != BADADDR:
        if idc.print_operand(cursor, 0) == wanted:
            mnem = idc.print_insn_mnem(cursor)
            if mnem[0:2] not in branch_prefixes and mnem[0:1] != "j":
                return cursor
        steps += 1
        if steps > max_steps:
            break
        cursor = idc.prev_head(cursor)
    return BADADDR
def getArg(start_addr, regNum):
    """Describe the value loaded into ``$a<regNum>`` for the call at ``start_addr``.

    The defining instruction is found with getArgAddr(). add*/sub*
    instructions are rendered as "x+y" / "x-y", plain load/move mnemonics as
    their source operand, and anything else as the disassembly text up to the
    first '#'. The defining instruction is also annotated with a comment that
    points back at the call.

    Returns the description, or ``"get fail"`` when no defining instruction
    was found.
    """
    load_mnemonics = ("move", "lw", "li", "lb", "lui", "lhu", "lbu", "la")
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr == BADADDR:
        return "get fail"

    mnem = idc.print_insn_mnem(arg_addr)
    family = mnem[0:3]
    if family in ("add", "sub"):
        sign = "+" if family == "add" else "-"
        # Two-operand form: the destination is also the first source.
        if idc.print_operand(arg_addr, 2) == "":
            left, right = idc.print_operand(arg_addr, 0), idc.print_operand(arg_addr, 1)
        else:
            left, right = idc.print_operand(arg_addr, 1), idc.print_operand(arg_addr, 2)
        arg = left + sign + right
    elif mnem in load_mnemonics:
        arg = idc.print_operand(arg_addr, 1)
    else:
        arg = idc.GetDisasm(arg_addr).split("#")[0]

    note = ("addr: 0x%x " % start_addr) + "-------> arg" + str(int(regNum) + 1) + " : " + arg
    idc.set_cmt(arg_addr, note, 0)
    return arg
def audit(func_name):
    """Print a table of every call site of ``func_name`` with its arguments.

    The number of reported arguments comes from the one/two/three_arg_function
    tables, or from format_function_offset_dict for format functions (which
    get an extra "format&value" column). Each code reference to the function
    is coloured green; references whose mnemonic starts with 'j' or 'b' are
    added to the table.

    Returns False when the function does not exist in the database; otherwise
    nothing (also when the argument count is unknown, after printing a hint).
    """
    func_addr = getFuncAddr(func_name)
    if not func_addr:
        return False

    is_format = func_name in format_function_offset_dict
    # get arg num and set table
    for count, names in ((1, one_arg_function), (2, two_arg_function), (3, three_arg_function)):
        if func_name in names:
            arg_num = count
            break
    else:
        if not is_format:
            print("The %s function didn't write in the describe arg num of function array, please add it to, such as add to `two_arg_function` array" % func_name)
            return
        arg_num = format_function_offset_dict[func_name] + 1

    # mipscall = ["jal", "jalr", "bal", "jr"]
    columns = ["func_name", "addr"] + ["arg" + str(i + 1) for i in range(arg_num)]
    if is_format:
        columns.append("format&value[string_addr, num of '%', fmt_arg...]")
    columns.append("local_buf_size")
    table = PrettyTable(columns)

    # get references to function (xrefs)
    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        # set color - green (red=0x0000ff, blue=0xff0000)
        idc.set_color(call_addr, idc.CIC_ITEM, 0x00ff00)
        # set break point
        # idc.add_bpt(call_addr)
        # idc.del_bpt(call_addr)
        if idc.print_insn_mnem(call_addr)[0:1] in ("j", "b"):
            collect = auditFormat if is_format else auditAddr
            table.add_row(collect(call_addr, func_name, arg_num))
    print(table)
# data_addr = DfirstB(func_addr)
# while data_addr != BADADDR:
# Mnemonics = GetMnem(data_addr)
# if DEBUG:
# print "Data Mnemonics : %s" % GetMnem(data_addr)
# print "Data addr : 0x %s" % data_addr
# data_addr = DnextB(func_addr, data_addr)
def auditAddr(call_addr, func_name, arg_num):
    """Build the table row for one call of a non-format function.

    Returns ``[func_name, "0x<addr>", arg1..argN, frame_size]`` where the
    frame size is the FUNCATTR_FRSIZE attribute of the enclosing function as
    hex text, or ``"get fail"`` when it is BADADDR.
    """
    row = [func_name, "0x%x" % call_addr]
    # local buf size
    frame_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)
    frame_text = "get fail" if frame_size == BADADDR else "0x%x" % frame_size
    # get arg
    row.extend(getArg(call_addr, idx) for idx in range(arg_num))
    row.append(frame_text)
    return row
def auditFormat(call_addr, func_name, arg_num):
    """Build the table row for one call of a format-string function.

    On top of what auditAddr() reports, the row carries a "format&value" cell:
    ``[string_addr, number of '%' characters, variadic args...]``, or
    ``"get fail"`` when the format string cannot be read. Variadic arguments
    are only collected while they still fit into $a0-$a3.
    """
    row = [func_name, "0x%x" % call_addr]
    # local buf size
    frame_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)
    frame_text = "get fail" if frame_size == BADADDR else "0x%x" % frame_size
    # get arg
    row.extend(getArg(call_addr, idx) for idx in range(arg_num))

    fmt_offset = format_function_offset_dict[func_name]
    fmt_lookup = getFormatString(getArgAddr(call_addr, fmt_offset))
    if fmt_lookup == "get fail":
        row.append("get fail")
    else:
        str_ea, fmt_text = fmt_lookup[0], fmt_lookup[1]
        percent_count = fmt_text.count("%")
        cell = ["0x%x" % str_ea, percent_count]
        # mips arg reg is from a0 to a3
        wanted = percent_count
        if wanted > 3:
            wanted = wanted - fmt_offset - 1
        for extra in range(wanted):
            if arg_num + extra > 3:
                break
            cell.append(getArg(call_addr, arg_num + extra))
        row.append(cell)
    row.append(frame_text)
    return row
# ============================================================
# Enhanced Detection Functions
# ============================================================
def getCallingFunction(addr):
    """Return the start address of the function containing ``addr``.

    BADADDR is returned when ``addr`` does not belong to a function.
    """
    owner = idaapi.get_func(addr)
    return owner.start_ea if owner else BADADDR
def traceArgSource(start_addr, regNum, depth=10):
    """Classify where argument ``$a<regNum>`` of the call at ``start_addr`` comes from.

    Returns a ``(source_type, source_info)`` tuple:
      * ``('static_string', {'addr', 'value', 'len'})`` - la/lui/li of a
        symbol that resolves to a non-empty string literal;
      * ``('stack_var', {'operand'})`` - the source operand text contains
        "sp" or "fp";
      * ``('external_input', {'function', 'addr'})`` - a ``move`` from a
        ``$v*`` register whose nearest preceding jal/jalr (at most 20
        instructions back) targets a name in external_input_functions;
      * ``('unknown', None)`` otherwise.

    ``depth`` only acts as a guard here: a value <= 0 yields 'unknown'. The
    function itself does not recurse.
    """
    unknown = ('unknown', None)
    if depth <= 0:
        return unknown
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr == BADADDR:
        return unknown

    mnem = idc.print_insn_mnem(arg_addr)
    source_op = idc.print_operand(arg_addr, 1)

    # Check if loading from a static string
    if mnem in ("la", "lui", "li"):
        symbol = source_op.split("+")[0].split("-")[0].replace("(", "")
        str_addr = idc.get_name_ea_simple(symbol)
        if str_addr != BADADDR:
            text = idc.get_strlit_contents(str_addr)
            if text:
                if isinstance(text, bytes):
                    text = text.decode('utf-8', errors='replace')
                return ('static_string', {'addr': str_addr, 'value': text, 'len': len(text)})

    # Check if it's a stack variable
    if "sp" in source_op or "fp" in source_op:
        return ('stack_var', {'operand': source_op})

    # A move from $v* is taken to be a function's return value: look for the
    # preceding call and stop at the first one found.
    if mnem == "move" and source_op.startswith("$v"):
        cursor = idc.prev_head(arg_addr)
        for _ in range(20):
            if cursor == BADADDR:
                break
            if idc.print_insn_mnem(cursor) in ("jal", "jalr"):
                callee = idc.print_operand(cursor, 0)
                if callee in external_input_functions:
                    return ('external_input', {'function': callee, 'addr': cursor})
                break
            cursor = idc.prev_head(cursor)
    return unknown
def checkCommandInjection(call_addr, func_name):
    """Rate the first argument of a command-execution call.

    For system(), popen() and the exec* family the source of ``$a0`` is
    traced. An external input source is reported as HIGH for all of them;
    system() additionally reports a stack variable as MEDIUM and a static
    string as LOW.

    Returns a list with at most one finding dict (``risk``, ``issue``,
    ``detail``); empty for any other function name.
    """
    exec_family = ("execve", "execl", "execle", "execlp", "execv", "execvp")
    if func_name == "system":
        external_issue = 'Command injection - arg from %s()'
    elif func_name == "popen":
        external_issue = 'Popen injection - arg from %s()'
    elif func_name in exec_family:
        external_issue = 'Exec injection - arg from %s()'
    else:
        return []

    source_type, source_info = traceArgSource(call_addr, 0, depth=15)
    if source_type == 'external_input':
        return [{
            'risk': 'HIGH',
            'issue': external_issue % source_info['function'],
            'detail': source_info
        }]
    # Only system() reports the weaker classifications.
    if func_name == "system":
        if source_type == 'stack_var':
            return [{
                'risk': 'MEDIUM',
                'issue': 'Command from stack variable (needs manual review)',
                'detail': source_info
            }]
        if source_type == 'static_string':
            return [{
                'risk': 'LOW',
                'issue': 'Static command string',
                'detail': source_info
            }]
    return []
def checkStackOverflow(call_addr, func_name):
    """Heuristically flag calls that may overflow a stack buffer.

    * strcpy/strcat: when the source (arg1) is a static string, its length is
      compared with the frame size (FUNCATTR_FRSIZE) of the enclosing
      function - longer is HIGH, more than half is MEDIUM. A source coming
      from an external input function is HIGH.
    * sprintf: a static format string containing ``%s`` is MEDIUM.
    * memcpy/strncpy/strncat: a size (arg2) coming from an external input
      function is HIGH.

    Returns a list of finding dicts (``risk``, ``issue``, ``detail``); empty
    for function names not listed above.
    """
    results = []
    if func_name in ("strcpy", "strcat"):
        # The frame size is only needed for this branch, so query it here.
        local_buf_size = idc.get_func_attr(call_addr, idc.FUNCATTR_FRSIZE)
        # Check source (arg1) for static string length
        source_type, source_info = traceArgSource(call_addr, 1, depth=15)
        if source_type == 'static_string' and source_info:
            src_len = source_info.get('len', 0)
            if local_buf_size != BADADDR and src_len > 0:
                if src_len > local_buf_size:
                    results.append({
                        'risk': 'HIGH',
                        'issue': 'Buffer overflow: src_len(%d) > frame_size(0x%x)' % (src_len, local_buf_size),
                        'detail': source_info
                    })
                elif src_len > local_buf_size // 2:
                    results.append({
                        'risk': 'MEDIUM',
                        'issue': 'Potential overflow: src_len(%d) close to frame_size(0x%x)' % (src_len, local_buf_size),
                        'detail': source_info
                    })
        elif source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': '%s with external input from %s()' % (func_name, source_info['function']),
                'detail': source_info
            })
    elif func_name == "sprintf":
        # Check format string for potential overflow
        arg_addr = getArgAddr(call_addr, 1)  # format string is arg1
        fmt_info = getFormatString(arg_addr)
        if fmt_info != "get fail":
            fmt_str = fmt_info[1]
            # Check for %s without width limit
            if '%s' in fmt_str:
                results.append({
                    'risk': 'MEDIUM',
                    # Plain literal (no %-formatting applied), so a single
                    # '%' is what should reach the report.
                    'issue': 'sprintf with %s (unbounded string copy)',
                    'detail': {'format': fmt_str}
                })
    elif func_name in ("memcpy", "strncpy", "strncat"):
        # Check size parameter (arg2)
        source_type, source_info = traceArgSource(call_addr, 2, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': '%s size from external input %s()' % (func_name, source_info['function']),
                'detail': source_info
            })
    return results
def checkIntegerOverflow(call_addr, func_name):
    """Flag allocation calls whose size may be attacker-influenced or computed.

    Handles malloc, calloc, realloc and memalign. The size argument (arg0 for
    malloc, arg1 otherwise; both arguments for calloc) is traced and reported
    as HIGH when it comes from an external input function. In addition, the
    10 instructions ending at the one that sets the size argument are scanned
    backwards; the first mul/mult/sll/add* found is reported as MEDIUM.

    Returns a list of finding dicts; empty for any other function name.
    """
    findings = []
    if func_name not in ("malloc", "calloc", "realloc", "memalign"):
        return findings

    # Check size argument
    size_arg_idx = 0 if func_name == "malloc" else 1
    if func_name == "calloc":
        # calloc(count, size) - check both arguments
        for idx in (0, 1):
            source_type, source_info = traceArgSource(call_addr, idx, depth=15)
            if source_type == 'external_input':
                findings.append({
                    'risk': 'HIGH',
                    'issue': 'calloc arg%d from external input %s()' % (idx, source_info['function']),
                    'detail': source_info
                })
    else:
        source_type, source_info = traceArgSource(call_addr, size_arg_idx, depth=15)
        if source_type == 'external_input':
            findings.append({
                'risk': 'HIGH',
                'issue': '%s size from external input %s()' % (func_name, source_info['function']),
                'detail': source_info
            })

    # Check for arithmetic operations on size (potential integer overflow)
    arithmetic = ("mul", "mult", "sll", "add", "addu", "addi", "addiu")
    cursor = getArgAddr(call_addr, size_arg_idx)
    if cursor != BADADDR:
        # Scan backward for arithmetic operations
        for _ in range(10):
            if cursor == BADADDR:
                break
            mnem = idc.print_insn_mnem(cursor)
            if mnem in arithmetic:
                findings.append({
                    'risk': 'MEDIUM',
                    'issue': 'Arithmetic (%s) before %s - check for integer overflow' % (mnem, func_name),
                    'detail': {'addr': "0x%x" % cursor, 'instruction': idc.GetDisasm(cursor)}
                })
                break
            cursor = idc.prev_head(cursor)
    return findings
def auditFreeCall(func_name):
    """Audit calls to a deallocation function for possible double frees.

    Call sites of ``func_name`` are grouped by containing function and the
    textual description of their first argument (from getArg()) is compared:
      * HIGH   - the same argument text is freed more than once in one
                 function; every call using a repeated argument is coloured
                 RISK_HIGH;
      * MEDIUM - more than two frees in one function, no repeated argument;
      * LOW    - everything else.
    Arguments that could not be resolved ("get fail") never count as
    repeated. The result is printed as a table; nothing is returned.
    """
    func_addr = idc.get_name_ea_simple(func_name)
    if func_addr == BADADDR:
        return
    print(printFunc(func_name + " (Double-Free Analysis)"))

    # Group by containing function
    calls_by_func = defaultdict(list)
    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        if idc.print_insn_mnem(call_addr)[0:1] in ("j", "b"):
            calls_by_func[getCallingFunction(call_addr)].append({
                'addr': call_addr,
                'arg': getArg(call_addr, 0)  # First argument to free
            })

    # Analyze each function for potential double-free
    table = PrettyTable(["function", "free_count", "addresses", "args", "risk"])
    for func_start, calls in calls_by_func.items():
        func_name_str = idc.get_func_name(func_start) or ("0x%x" % func_start)
        addrs = [("0x%x" % c['addr']) for c in calls]
        args = [c['arg'] for c in calls]

        arg_counts = defaultdict(int)
        for arg in args:
            arg_counts[arg] += 1
        repeated = set(arg for arg, count in arg_counts.items()
                       if count > 1 and arg != "get fail")

        if repeated:
            risk = "HIGH"
            # Highlight the calls that actually share an argument, rather than
            # whichever call happens to be first in the function.
            for c in calls:
                if c['arg'] in repeated:
                    idc.set_color(c['addr'], idc.CIC_ITEM, RISK_HIGH)
        elif len(calls) > 2:
            risk = "MEDIUM"
        else:
            risk = "LOW"
        table.add_row([func_name_str, len(calls), ", ".join(addrs), ", ".join(args), risk])
    print(table)
def auditEnhanced(func_name):
    """Run the enhanced vulnerability checks on every call of ``func_name``.

    Which checks run depends on the lists the name belongs to: command
    injection (command_execution_function), stack overflow
    (dangerous_functions / attention_function), integer overflow
    (memory_alloc_functions) and format string (format_string_functions).
    Each finding is printed in a table and appended to the global
    ``audit_results`` with ``address`` and ``function`` keys added. The call
    site is coloured once, according to the highest risk among its findings.

    Returns False when the function is not present in the database; otherwise
    nothing.
    """
    global audit_results
    func_addr = idc.get_name_ea_simple(func_name)
    if func_addr == BADADDR:
        return False
    print(printFunc(func_name + " (Enhanced Analysis)"))

    # Determine check type based on function
    check_cmd_injection = func_name in command_execution_function
    check_overflow = func_name in dangerous_functions + attention_function
    check_int_overflow = func_name in memory_alloc_functions
    check_format_string = func_name in format_string_functions

    # Unlisted risk strings rank lowest and use the LOW colour.
    risk_rank = {'HIGH': 2, 'MEDIUM': 1}
    risk_color = {'HIGH': RISK_HIGH, 'MEDIUM': RISK_MEDIUM}

    table = PrettyTable(["addr", "risk", "issue", "detail"])
    for call_addr in idautils.CodeRefsTo(func_addr, 0):
        if idc.print_insn_mnem(call_addr)[0:1] not in ("j", "b"):
            continue
        findings = []
        if check_cmd_injection:
            findings.extend(checkCommandInjection(call_addr, func_name))
        if check_overflow:
            findings.extend(checkStackOverflow(call_addr, func_name))
        if check_int_overflow:
            findings.extend(checkIntegerOverflow(call_addr, func_name))
        if check_format_string:
            findings.extend(checkFormatStringVuln(call_addr, func_name))
        if not findings:
            continue

        # Colour by the worst finding so that a later MEDIUM/LOW result cannot
        # overwrite an earlier HIGH highlight on the same call.
        worst = max((f['risk'] for f in findings), key=lambda r: risk_rank.get(r, 0))
        idc.set_color(call_addr, idc.CIC_ITEM, risk_color.get(worst, RISK_LOW))

        # Add findings to table and global results
        for finding in findings:
            detail_str = str(finding.get('detail', ''))[:50]
            table.add_row(["0x%x" % call_addr, finding['risk'], finding['issue'], detail_str])
            # Add to global results for export
            finding['address'] = "0x%x" % call_addr
            finding['function'] = func_name
            audit_results.append(finding)

    if table.rowcount > 0:
        print(table)
    else:
        print(" No enhanced findings for %s" % func_name)
def collectTaintSources():
    """Rebuild the global ``taint_sources`` map of external-input call sites.

    For every name in external_input_functions that exists in the database,
    each code reference whose mnemonic starts with 'j' or 'b' is recorded as
    ``call address -> function name``.

    Returns the rebuilt dict.
    """
    global taint_sources
    taint_sources = {}
    print("\n[*] Collecting external input sources...")
    for source_name in external_input_functions:
        source_ea = idc.get_name_ea_simple(source_name)
        if source_ea == BADADDR:
            continue
        for call_addr in idautils.CodeRefsTo(source_ea, 0):
            if idc.print_insn_mnem(call_addr)[0:1] in ("j", "b"):
                taint_sources[call_addr] = source_name
    print(" Found %d external input call sites" % len(taint_sources))
    return taint_sources
# ============================================================
# Basic Block Analysis (Fix cross-block issues)
# ============================================================
def getBasicBlockBounds(addr):
    """Return ``(start_ea, end_ea)`` of the basic block containing ``addr``.

    ``end_ea`` is exclusive. ``(BADADDR, BADADDR)`` is returned when ``addr``
    is not inside a function, when no block of the function's flow chart
    contains it, or when building the flow chart fails.
    """
    func = idaapi.get_func(addr)
    if not func:
        return (BADADDR, BADADDR)
    try:
        for block in idaapi.FlowChart(func):
            if block.start_ea <= addr < block.end_ea:
                return (block.start_ea, block.end_ea)
    except Exception:
        # Best effort: treat a flow-chart failure as "bounds unknown" without
        # also swallowing KeyboardInterrupt/SystemExit like a bare except.
        pass
    return (BADADDR, BADADDR)
def getArgAddrInBlock(start_addr, regNum):
    """Locate the instruction that sets ``$a<regNum>`` without leaving the basic block.

    Works like getArgAddr(), except that the backward walk always starts at
    the instruction before ``start_addr`` and stops once it would move above
    the start of the basic block containing ``start_addr`` (when the block
    bounds are known). The walk gives up after 50 steps.

    Returns the instruction address, or BADADDR when nothing was found.
    """
    branch_prefixes = ("bn", "be", "bg", "bl")
    max_steps = 50
    wanted = "$a" + str(regNum)
    # Get basic block bounds (only the lower bound is needed)
    block_start, _ = getBasicBlockBounds(start_addr)

    # try to get in the next (code references from this address)
    for ref in idautils.CodeRefsFrom(start_addr, 0):
        if ref != BADADDR and idc.print_operand(ref, 0) == wanted:
            return ref

    # try to get before (within same basic block)
    steps = 0
    cursor = idc.prev_head(start_addr)
    while cursor != BADADDR:
        # Stop if we cross basic block boundary
        if block_start != BADADDR and cursor < block_start:
            break
        if idc.print_operand(cursor, 0) == wanted:
            mnem = idc.print_insn_mnem(cursor)
            if mnem[0:2] not in branch_prefixes and mnem[0:1] != "j":
                return cursor
        steps += 1
        if steps > max_steps:
            break
        cursor = idc.prev_head(cursor)
    return BADADDR
# ============================================================
# Format String Vulnerability Detection
# ============================================================
def _formatConversions(fmt_str):
    """Return the conversion characters of the printf directives in ``fmt_str``.

    The string is scanned left to right so that an escaped percent (``%%``)
    is consumed as a unit and never mistaken for the start of a directive.
    Flags, width, precision, positional ``n$`` and length modifiers are
    skipped; e.g. ``"%5$n %%n %ld"`` yields ``['n', 'd']``.
    """
    directive = r"%(?:%|[-+ #0-9.*$']*(?:hh|ll|[hlLqjzt])?([a-zA-Z]))"
    return [m.group(1) for m in re.finditer(directive, fmt_str) if m.group(1)]


def checkFormatStringVuln(call_addr, func_name):
    """Check one format-function call for format string problems.

    * A static format string containing an ``n`` conversion (write primitive)
      is HIGH; one with more than 10 directives is MEDIUM.
    * When the format string cannot be read statically, its source is traced:
      an external input function is HIGH, a stack variable is MEDIUM.

    Returns a list of finding dicts (``risk``, ``issue``, ``detail``,
    ``type``); empty when the function has no entry in
    format_function_offset_dict or the format argument cannot be located.
    """
    results = []
    if func_name not in format_function_offset_dict:
        return results
    fmt_arg_idx = format_function_offset_dict[func_name]
    arg_addr = getArgAddrInBlock(call_addr, fmt_arg_idx)
    if arg_addr == BADADDR:
        return results

    # Try to get the format string
    fmt_info = getFormatString(arg_addr)
    if fmt_info != "get fail":
        fmt_str = fmt_info[1]
        conversions = _formatConversions(fmt_str)
        # Check for %n (write primitive - HIGH risk). Looking at parsed
        # conversions catches %5$n / %lln and ignores an escaped "%%n".
        if 'n' in conversions:
            results.append({
                'risk': 'HIGH',
                # Plain literal (no %-formatting applied): single '%'.
                'issue': 'Format string with %n write primitive',
                'detail': {'format': fmt_str, 'addr': "0x%x" % fmt_info[0]},
                'type': 'format_string'
            })
        # Check for multiple format specifiers (potential overflow)
        fmt_count = len(conversions)
        if fmt_count > 10:
            results.append({
                'risk': 'MEDIUM',
                'issue': 'Format string with many specifiers (%d)' % fmt_count,
                'detail': {'format': fmt_str[:50] + '...', 'count': fmt_count},
                'type': 'format_string'
            })
    else:
        # Format string is not static - check if user controlled
        source_type, source_info = traceArgSource(call_addr, fmt_arg_idx, depth=15)
        if source_type == 'external_input':
            results.append({
                'risk': 'HIGH',
                'issue': 'User-controlled format string from %s()' % source_info['function'],
                'detail': source_info,
                'type': 'format_string'
            })
        elif source_type == 'stack_var':
            results.append({
                'risk': 'MEDIUM',
                'issue': 'Format string from stack variable (needs review)',
                'detail': source_info,
                'type': 'format_string'
            })
    return results
# ============================================================
# Data Flow Analysis - Forward Tracking
# ============================================================
def traceDataFlowForward(start_addr, src_func, max_depth=5):
    """Find dangerous calls that receive the buffer filled by an input call.

    ``start_addr`` is a call to the input function ``src_func``. The textual
    description of its buffer argument is taken from getArg() - arg1 for
    read/recv/recvfrom (fd or socket comes first), arg0 otherwise (fgets,
    gets and fread take the buffer first). Up to 100 following instructions
    of the same function are scanned; for each jal/jalr whose target is a
    dangerous, command-execution or format-string function, $a0-$a3 are
    compared with the buffer text by substring match.

    Returns a list of finding dicts (HIGH when the sink is a
    command-execution function, MEDIUM otherwise), at most one per sink call.
    ``max_depth`` only acts as a guard: a value <= 0 yields an empty list.
    """
    results = []
    if max_depth <= 0:
        return results
    # Get the function containing this call
    func = idaapi.get_func(start_addr)
    if not func:
        return results

    # read(fd, buf, n), recv(sock, buf, ...), recvfrom(sock, buf, ...) take
    # the buffer as arg1. fread(ptr, size, nmemb, stream) takes it as arg0,
    # so it must not be in this list.
    buffer_arg_idx = 1 if src_func in ("read", "recv", "recvfrom") else 0
    buffer_operand = getArg(start_addr, buffer_arg_idx)
    # An empty description would match every operand in the substring test.
    if not buffer_operand or buffer_operand == "get fail":
        return results

    # Loop-invariant: the sink names do not change while scanning.
    all_dangerous = set(dangerous_functions + command_execution_function + format_string_functions)

    # Scan forward in the function to find uses of this buffer
    current_addr = idc.next_head(start_addr)
    func_end = func.end_ea
    scan_count = 0
    max_scan = 100
    while current_addr < func_end and current_addr != BADADDR and scan_count < max_scan:
        # Check if this is a function call
        if idc.print_insn_mnem(current_addr) in ("jal", "jalr"):
            call_target = idc.print_operand(current_addr, 0)
            if call_target in all_dangerous:
                # Check if our buffer is passed as an argument
                for arg_idx in range(4):  # MIPS uses $a0-$a3
                    arg_operand = getArg(current_addr, arg_idx)
                    # Unresolved or empty arguments carry no information.
                    if not arg_operand or arg_operand == "get fail":
                        continue
                    if buffer_operand in arg_operand or arg_operand in buffer_operand:
                        risk = 'HIGH' if call_target in command_execution_function else 'MEDIUM'
                        results.append({
                            'risk': risk,
                            'issue': 'Tainted data from %s() flows to %s()' % (src_func, call_target),
                            'detail': {
                                'source': "0x%x" % start_addr,
                                'sink': "0x%x" % current_addr,
                                'sink_func': call_target,
                                'buffer': buffer_operand
                            },
                            'type': 'data_flow'
                        })
                        break
        current_addr = idc.next_head(current_addr)
        scan_count += 1
    return results
def analyzeDataFlow():
    """
    Run forward data-flow tracing for every recorded input call.

    Each entry of ``taint_sources`` whose function is one of the tracked
    input functions is passed to traceDataFlowForward(); every finding is
    collected and its (sink address, sink function) pair is appended to
    ``data_flow_graph`` under the source call address.

    Returns:
        The combined list of findings.
    """
    global data_flow_graph
    tracked_sources = ("read", "recv", "recvfrom", "fread", "fgets", "gets")
    findings = []
    print("\n[*] Analyzing data flow from external inputs...")
    for call_addr, func_name in taint_sources.items():
        if func_name not in tracked_sources:
            continue
        traced = traceDataFlowForward(call_addr, func_name)
        findings.extend(traced)
        # Store in graph for visualization
        for item in traced:
            if 'detail' in item and 'sink' in item['detail']:
                detail = item['detail']
                data_flow_graph[call_addr].append((detail['sink'], detail['sink_func']))
    print(" Found %d data flow issues" % len(findings))
    return findings
# ============================================================
# Wrapper Function Detection
# ============================================================
def detectWrapperFunctions():
    """
    Detect wrapper functions that call dangerous functions
    e.g., my_strcpy that internally calls strcpy

    A calling function is treated as a wrapper when either
      * its name contains one of the common wrapper patterns, or
      * it is smaller than 100 bytes and contains at most two
        jal/jalr instructions.

    Returns:
        dict mapping the wrapper's start address to
        {'name': str, 'wraps': [dangerous function names], 'type': str};
        the same dict is stored in the global ``wrapper_functions``.
    """
    global wrapper_functions
    wrapper_functions = {}
    print("\n[*] Detecting wrapper functions...")

    wrapper_patterns = ('my_', 'safe_', 'wrap_', '_wrapper', '_safe', 'do_', 'internal_')

    for dangerous_func in dangerous_functions + command_execution_function:
        func_addr = idc.get_name_ea_simple(dangerous_func)
        if func_addr == BADADDR:
            continue

        # Find all callers of this dangerous function
        for call_addr in idautils.CodeRefsTo(func_addr, 0):
            mnem = idc.print_insn_mnem(call_addr)
            if mnem[0:1] not in ("j", "b"):
                continue

            # Get the function containing this call
            caller_func = idaapi.get_func(call_addr)
            if not caller_func:
                continue
            caller_name = idc.get_func_name(caller_func.start_ea)
            if not caller_name:
                continue

            wrapper_type = None
            lowered_name = caller_name.lower()
            if any(pattern in lowered_name for pattern in wrapper_patterns):
                wrapper_type = 'name_match'
            elif caller_func.end_ea - caller_func.start_ea < 100:
                # Less than ~25 instructions: count the calls it makes
                call_count = sum(
                    1 for addr in idautils.FuncItems(caller_func.start_ea)
                    if idc.print_insn_mnem(addr) in ("jal", "jalr")
                )
                if call_count <= 2:  # Mostly just calls the dangerous function
                    wrapper_type = 'small_func'

            if wrapper_type is None:
                continue

            entry = wrapper_functions.setdefault(caller_func.start_ea, {
                'name': caller_name,
                'wraps': [],
                'type': wrapper_type
            })
            # A wrapper may call the same function from several sites;
            # list each wrapped function only once.
            if dangerous_func not in entry['wraps']:
                entry['wraps'].append(dangerous_func)

    print(" Found %d potential wrapper functions" % len(wrapper_functions))
    return wrapper_functions
def auditWrapperFunctions():
    """
    Report the detected wrapper functions and every call made to them.

    Runs detectWrapperFunctions() first when no wrappers are known yet.
    Each wrapper function and each call site is coloured with RISK_MEDIUM,
    and both lists are printed as tables.
    """
    if not wrapper_functions:
        detectWrapperFunctions()
    if not wrapper_functions:
        print(" No wrapper functions detected")
        return

    print(printFunc("Wrapper Functions"))
    summary = PrettyTable(["wrapper_name", "address", "wraps", "detection_type"])
    for func_addr, info in wrapper_functions.items():
        idc.set_color(func_addr, idc.CIC_FUNC, RISK_MEDIUM)
        wrapped = ", ".join(info['wraps'])
        summary.add_row([info['name'], "0x%x" % func_addr, wrapped, info['type']])
    print(summary)

    # Also audit calls to wrapper functions
    print("\n Calls to wrapper functions:")
    call_sites = PrettyTable(["wrapper", "call_addr", "caller_func"])
    for func_addr, info in wrapper_functions.items():
        for call_addr in idautils.CodeRefsTo(func_addr, 0):
            # Only jump/branch style instructions count as calls
            if idc.print_insn_mnem(call_addr)[0:1] not in ("j", "b"):
                continue
            caller_name = idc.get_func_name(call_addr)
            idc.set_color(call_addr, idc.CIC_ITEM, RISK_MEDIUM)
            call_sites.add_row([info['name'], "0x%x" % call_addr, caller_name or "unknown"])
    if call_sites.rowcount > 0:
        print(call_sites)
# ============================================================
# Result Export Functions
# ============================================================
def addFinding(finding):
    """Append *finding* to the module-level ``audit_results`` list."""
    global audit_results
    audit_results.append(finding)
def exportResultsJSON(filename):
    """Write ``audit_results`` as indented UTF-8 JSON into the output directory.

    Returns True on success; on any error the message is printed and
    False is returned.
    """
    try:
        target = os.path.join(get_output_dir(), filename)
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(audit_results, handle, indent=2, ensure_ascii=False)
        print("[*] Results exported to: %s" % target)
    except Exception as err:
        print("[!] Error exporting JSON: %s" % str(err))
        return False
    else:
        return True
def exportResultsCSV(filename):
    """Write ``audit_results`` as a CSV file into the output directory.

    One row is written per finding, with the 'detail' value converted to
    text and cut to 100 characters. When there are no findings the file is
    created empty (no header). Returns True on success, False after
    printing the error otherwise.
    """
    columns = ['risk', 'issue', 'type', 'address', 'function', 'detail']
    try:
        target = os.path.join(get_output_dir(), filename)
        with open(target, 'w', newline='', encoding='utf-8') as handle:
            if audit_results:
                writer = csv.DictWriter(handle, fieldnames=columns)
                writer.writeheader()
                for finding in audit_results:
                    record = {column: finding.get(column, '') for column in columns}
                    record['detail'] = str(record['detail'])[:100]
                    writer.writerow(record)
        print("[*] Results exported to: %s" % target)
    except Exception as err:
        print("[!] Error exporting CSV: %s" % str(err))
        return False
    else:
        return True
def exportResultsHTML(filename):
    """Export audit results to HTML file

    Builds a standalone HTML page (summary counts followed by one table
    row per finding) and writes it to ``get_output_dir()/filename``.
    Every value taken from a finding is HTML-escaped before insertion and
    the 'detail' value is cut to 100 characters.

    Args:
        filename: name of the report file to create in the output directory.

    Returns:
        True when the file was written, False when any error occurred
        (the error is printed, not raised).
    """
    import html  # local import: only this function needs it

    def esc(value):
        # Finding text may come from the disassembly; never emit it raw.
        return html.escape(str(value), quote=True)

    try:
        output_path = os.path.join(get_output_dir(), filename)

        counts = {'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
        for result in audit_results:
            level = result.get('risk')
            if level in counts:
                counts[level] += 1

        # Collect fragments and join once instead of repeated string +=.
        parts = ["""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>MIPS Audit Report</title>
<style>
body { font-family: sans-serif; }
table { border-collapse: collapse; }
th, td { border: 1px solid #999; padding: 4px 8px; text-align: left; }
tr.risk-HIGH td { background: #ffcccc; }
tr.risk-MEDIUM td { background: #ffe0b3; }
tr.risk-LOW td { background: #ccffcc; }
</style>
</head>
<body>
<h1>MIPS Security Audit Report</h1>
<p>Total Findings: %d |
HIGH: %d |
MEDIUM: %d |
LOW: %d</p>
<table>
<tr>
<th>Risk</th>
<th>Issue</th>
<th>Type</th>
<th>Address</th>
<th>Function</th>
<th>Detail</th>
</tr>
""" % (len(audit_results), counts['HIGH'], counts['MEDIUM'], counts['LOW'])]

        for result in audit_results:
            risk = esc(result.get('risk', ''))
            parts.append("""<tr class="risk-%s">
<td>%s</td>
<td>%s</td>
<td>%s</td>
<td>%s</td>
<td>%s</td>
<td>%s</td>
</tr>
""" % (
                risk,
                risk,
                esc(result.get('issue', '')),
                esc(result.get('type', '')),
                esc(result.get('address', '')),
                esc(result.get('function', '')),
                # Truncate first so an escape entity is never cut in half.
                esc(str(result.get('detail', ''))[:100])
            ))

        parts.append("""</table>
</body>
</html>
""")

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("".join(parts))
        print("[*] Results exported to: %s" % output_path)
        return True
    except Exception as e:
        print("[!] Error exporting HTML: %s" % str(e))
        return False
def exportResults(base_filename="mipsAudit_results"):
    """Export results as HTML with timestamp

    The file name is ``<base_filename>_<YYYYmmdd_HHMMSS>.html``.

    Returns:
        The generated file name, or None when exportResultsHTML() reported
        a failure, so callers do not announce a file that was never written.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = "%s_%s.html" % (base_filename, timestamp)
    if not exportResultsHTML(filename):
        return None
    return filename
def mipsAudit():
    """Main audit function with all enhanced features

    Resets ``audit_results``, prints the banner, loads the optional
    configuration, collects taint sources and then runs four phases:
    basic per-function audit, enhanced vulnerability detection, advanced
    analysis (data flow and wrapper functions), and summary plus export.
    """
    global audit_results, total_functions, processed_functions

    # Reset global state
    audit_results = []

    # the word create with figlet
    # Raw string: the art contains backslashes that are not escape sequences.
    start = r'''
_ _ _ _ _
_ __ ___ (_)_ __ ___ / \ _ _ __| (_) |_
| '_ ` _ \| | '_ \/ __| / _ \| | | |/ _` | | __|
| | | | | | | |_) \__ \/ ___ \ |_| | (_| | | |_
|_| |_| |_|_| .__/|___/_/ \_\__,_|\__,_|_|\__|
|_|
code by giantbranch 2018.05
updated for IDA 7.x+ & Python 3
enhanced detection v3.0 20260129
'''
    print(start)

    # Load external configuration if available
    load_config()

    # Calculate total functions for progress display
    all_functions = (dangerous_functions + attention_function +
                     command_execution_function + memory_alloc_functions +
                     memory_free_functions + format_string_functions)
    total_functions = len(set(all_functions))
    processed_functions = 0

    # Collect taint sources first for enhanced analysis
    collectTaintSources()

    separator = "=" * 60

    print("\n" + separator)
    print(" PHASE 1: Basic Function Audit")
    print(separator)
    basic_groups = (
        ("\nAuditing dangerous functions ......", dangerous_functions, "Dangerous funcs"),
        ("\nAuditing attention function ......", attention_function, "Attention funcs"),
        ("\nAuditing command execution function ......", command_execution_function, "Cmd exec funcs"),
    )
    for heading, names, label in basic_groups:
        print(heading)
        for position, func_name in enumerate(names, 1):
            audit(func_name)
            show_progress(position, len(names), label)

    print("\n" + separator)
    print(" PHASE 2: Enhanced Vulnerability Detection")
    print(separator)
    enhanced_groups = (
        ("\n[*] Enhanced analysis: Command Injection Detection", command_execution_function, auditEnhanced),
        ("\n[*] Enhanced analysis: Stack Overflow Detection", dangerous_functions, auditEnhanced),
        ("\n[*] Enhanced analysis: Integer Overflow Detection", memory_alloc_functions, auditEnhanced),
        ("\n[*] Enhanced analysis: Format String Vulnerability Detection", format_string_functions, auditEnhanced),
        ("\n[*] Enhanced analysis: Double-Free Detection", memory_free_functions, auditFreeCall),
    )
    for heading, names, analyzer in enhanced_groups:
        print(heading)
        for func_name in names:
            analyzer(func_name)

    print("\n" + separator)
    print(" PHASE 3: Advanced Analysis")
    print(separator)

    # Data flow analysis
    flow_results = analyzeDataFlow()
    if flow_results:
        print("\nData Flow Analysis Results:")
        table = PrettyTable(["risk", "issue", "source", "sink"])
        for r in flow_results:
            table.add_row([
                r['risk'],
                r['issue'],
                r['detail'].get('source', ''),
                r['detail'].get('sink', '')
            ])
            audit_results.append(r)
        print(table)

    # Wrapper function detection
    print("\n[*] Detecting and auditing wrapper functions...")
    auditWrapperFunctions()

    print("\n" + separator)
    print(" PHASE 4: Results Summary & Export")
    print(separator)

    # Summary
    risk_counts = {'HIGH': 0, 'MEDIUM': 0, 'LOW': 0}
    for r in audit_results:
        level = r.get('risk')
        if level in risk_counts:
            risk_counts[level] += 1
    print("\n[*] Audit Summary:")
    print(" Total findings: %d" % len(audit_results))
    print(" HIGH risk: %d" % risk_counts['HIGH'])
    print(" MEDIUM risk: %d" % risk_counts['MEDIUM'])
    print(" LOW risk: %d" % risk_counts['LOW'])

    # Export results
    if audit_results:
        print("\n[*] Exporting results...")
        export_file = exportResults()
        # Only announce the file when the exporter returned a name.
        if export_file:
            print("\nExported: %s" % export_file)

    print("\n" + separator)
    print(" Finished! Enjoy the result ~")
    print(separator)
# Check processor architecture (may be useful in future)
# info = idaapi.get_inf_structure()
#
# if info.is_64bit():
# bits = 64
# elif info.is_32bit():
# bits = 32
# else:
# bits = 16
#
# try:
# is_be = info.is_be()
# except:
# is_be = info.mf
# endian = "big" if is_be else "little"
#
# print('Processor: {}, {}bit, {} endian'.format(info.procName, bits, endian))
# # Result: Processor: mipsr, 32bit, big endian
# Entry point: the audit runs as soon as this script is executed.
mipsAudit()
================================================
FILE: oldversion(py2)/mipsAudit.py
================================================
# -*- coding: utf-8 -*-
# reference
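# "The IDA Pro Book" (Chinese edition: 《ida pro 权威指南》)
# "Gray Hat Python" (Chinese edition: 《python 灰帽子》)
# "Home Router 0day Vulnerability Discovery" (《家用路由器0day漏洞挖掘》)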
# https://github.com/wangzery/SearchOverflow/blob/master/SearchOverflow.py
from idaapi import *
from prettytable import PrettyTable
# Debug switch; only referenced from commented-out code in the visible part of this script.
DEBUG = True
# Not covered by the tables below (presumably a to-do list): fgetc, fgets, fread, fprintf,
# vsprintf
# Names of the imported functions to audit, grouped by how risky a call is considered.
dangerous_functions = [
"strcpy",
"strcat",
"sprintf",
"read",
"getenv"
]
# Functions that deserve a look but are not flagged as outright dangerous.
attention_function = [
"memcpy",
"strncpy",
"sscanf",
"strncat",
"snprintf",
"vprintf",
"printf"
]
# Functions grouped under "command execution" by the audit.
command_execution_function = [
"system",
"execve",
"popen",
"unlink"
]
# Argument counts: audit() uses these lists to decide how many argument
# registers to resolve for each function.
one_arg_function = [
"getenv",
"system",
"unlink"
]
two_arg_function = [
"strcpy",
"strcat",
"popen"
]
three_arg_function = [
"strncpy",
"strncat",
"memcpy",
"execve",
"read"
]
# For format-string functions: zero-based index of the format-string argument.
format_function_offset_dict = {
"sprintf":1,
"sscanf":1,
"snprintf":2,
"vprintf":0,
"printf":0
}
def printFunc(func_name):
    """Return a three-line banner announcing the audit of *func_name*.

    The middle line is padded with '=' up to the width of the rule lines;
    longer titles are left unpadded.
    """
    rule = "========================================"
    title = "========== Aduiting " + func_name + " "
    padding = '=' * (len(rule) - len(title))
    return "\n".join([rule, title + padding, rule])
def getFuncAddr(func_name):
    """Resolve *func_name* to its address and print the audit banner.

    Returns the address reported by LocByName, or False when the name is
    unknown (BADADDR).
    """
    func_addr = LocByName(func_name)
    if func_addr == BADADDR:
        return False
    print(printFunc(func_name))
    # print func_name + " Addr : 0x %x" % func_addr
    return func_addr
def getFormatString(addr):
    """Fetch the string referenced by the immediate operand of the instruction at *addr*.

    Operand 1 is tried first and operand 2 second. Returns
    [string_addr, string] on success or the text "get fail".
    """
    # GetOpType return values:
    #   0 o_void    no operand
    #   1 o_reg     general register
    #   2 o_mem     direct memory reference (DATA)
    #   3 o_phrase  memory ref [base reg + index reg]
    #   4 o_displ   memory ref [base reg + index reg + displacement]
    #   5 o_imm     immediate value
    #   6 o_far     immediate far address (CODE)
    #   7 o_near    immediate near address (CODE)
    #   8..13       processor specific types
    immediate = 5
    # If the second operand is not an immediate, fall back to the next one.
    for op_num in (1, 2):
        if GetOpType(addr, op_num) == immediate:
            break
    else:
        return "get fail"
    operand_text = GetOpnd(addr, op_num)
    # Keep only the symbol name: drop anything after a space, '+' or '-',
    # and remove opening parentheses.
    symbol = operand_text.split(" ")[0].split("+")[0].split("-")[0].replace("(", "")
    string_addr = LocByName(symbol)
    if string_addr == BADADDR:
        return "get fail"
    return [string_addr, str(GetString(string_addr))]
def getArgAddr(start_addr, regNum):
"""Locate the instruction that sets argument register $a<regNum> for the call at start_addr.

The first code reference from start_addr (Rfirst) is checked first; after
that the code references leading to start_addr are followed backwards
(RfirstB). Returns the address of the matching instruction, or BADADDR
when none is found.
"""
# Mnemonic prefixes treated as conditional branches.
mipscondition = ["bn", "be" , "bg", "bl"]
# Upper bound used together with `count` below.
# NOTE(review): the original nesting of the `count` lines is not visible here;
# it looks like a cap on how far the backward walk may go - confirm.
scan_deep = 50
count = 0
reg = "$a" + str(regNum)
# try to get in the next
# (presumably the branch delay slot after the call - confirm)
next_addr = Rfirst(start_addr)
if next_addr != BADADDR and reg == GetOpnd(next_addr, 0):
return next_addr
# try to get before
before_addr = RfirstB(start_addr)
while before_addr != BADADDR:
if reg == GetOpnd(before_addr, 0):
Mnemonics = GetMnem(before_addr)
# A branch or jump naming the register as its first operand does not
# assign it, so it is ignored.
if Mnemonics[0:2] in mipscondition:
pass
elif Mnemonics[0:1] == "j":
pass
else:
return before_addr
count = count + 1
if count > scan_deep:
break
before_addr = RfirstB(before_addr)
return BADADDR
def getArg(start_addr, regNum):
    """Describe the value loaded into argument register $a<regNum> for the call at start_addr.

    The instruction found by getArgAddr() is rendered as text (operands
    joined with '+' or '-' for add*/sub* instructions, the source operand
    for plain moves and loads, otherwise the disassembly up to '#'). The
    same text is attached as a comment to that instruction.

    Returns the text, or "get fail" when no instruction was found.
    """
    arg_addr = getArgAddr(start_addr, regNum)
    if arg_addr == BADADDR:
        return "get fail"

    load_mnemonics = ("move", "lw", "li", "lb", "lui", "lhu", "lbu", "la")
    mnemonic = GetMnem(arg_addr)
    family = mnemonic[0:3]
    if family in ("add", "sub"):
        sign = "+" if family == "add" else "-"
        # Two-operand form: the destination doubles as the first source.
        if GetOpnd(arg_addr, 2) == "":
            arg = GetOpnd(arg_addr, 0) + sign + GetOpnd(arg_addr, 1)
        else:
            arg = GetOpnd(arg_addr, 1) + sign + GetOpnd(arg_addr, 2)
    elif mnemonic in load_mnemonics:
        arg = GetOpnd(arg_addr, 1)
    else:
        arg = GetDisasm(arg_addr).split("#")[0]

    MakeComm(arg_addr, "addr: 0x%x " % start_addr + "-------> arg" + str((int(regNum)+1)) + " : " + arg)
    return arg
def audit(func_name):
    """Audit every call to *func_name* and print the results as a table.

    Each code reference to the function is coloured green; references whose
    mnemonic starts with 'j' or 'b' are treated as calls and get one table
    row with their resolved arguments.

    Returns False when the function is not present in the database and
    None otherwise.
    """
    func_addr = getFuncAddr(func_name)
    if func_addr == False:
        return False

    # get arg num and set table
    arg_num = None
    for names, count in ((one_arg_function, 1),
                         (two_arg_function, 2),
                         (three_arg_function, 3)):
        if func_name in names:
            arg_num = count
            break
    if arg_num is None:
        if func_name in format_function_offset_dict:
            arg_num = format_function_offset_dict[func_name] + 1
        else:
            print("The %s function didn't write in the describe arg num of function array,please add it to,such as add to `two_arg_function` arary" % func_name)
            return

    is_format = func_name in format_function_offset_dict

    # mispcall = ["jal", "jalr", "bal", "jr"]
    table_head = ["func_name", "addr"]
    table_head.extend("arg" + str(position) for position in range(1, arg_num + 1))
    if is_format:
        table_head.append("format&value[string_addr, num of '%', fmt_arg...]")
    table_head.append("local_buf_size")
    table = PrettyTable(table_head)

    build_row = auditFormat if is_format else auditAddr

    # walk every code reference to the function
    call_addr = RfirstB(func_addr)
    while call_addr != BADADDR:
        # set color: green (red = 0x0000ff, blue = 0xff0000)
        SetColor(call_addr, CIC_ITEM, 0x00ff00)
        # set break point
        # AddBpt(call_addr)
        # DelBpt(call_addr)
        # if you want to use condition
        # SetBptCnd(ea, 'strstr(GetString(Dword(esp+4),-1, 0), "SAEXT.DLL") != -1')
        if GetMnem(call_addr)[0:1] in ("j", "b"):
            table.add_row(build_row(call_addr, func_name, arg_num))
        call_addr = RnextB(func_addr, call_addr)
    print(table)
# data_addr = DfirstB(func_addr)
# while data_addr != BADADDR:
# Mnemonics = GetMnem(data_addr)
# if DEBUG:
# print "Data Mnemonics : %s" % GetMnem(data_addr)
# print "Data addr : 0x %s" % data_addr
# data_addr = DnextB(func_addr, data_addr)
def auditAddr(call_addr, func_name, arg_num):
    """Build the table row for one call to a non-format function.

    Returns [func_name, call address, arg1..argN, local_buf_size], where
    local_buf_size is the FUNCATTR_FRSIZE attribute queried for call_addr
    as hex text, or "get fail".
    """
    # local buf size
    frame_size = GetFunctionAttr(call_addr, FUNCATTR_FRSIZE)
    if frame_size == BADADDR:
        local_buf_size = "get fail"
    else:
        local_buf_size = "0x%x" % frame_size

    row = [func_name, "0x%x" % call_addr]
    # get arg
    row.extend([getArg(call_addr, index) for index in range(arg_num)])
    row.append(local_buf_size)
    return row
def auditFormat(call_addr, func_name, arg_num):
"""Build the table row for one call to a format-string function.

The row holds the function name, the call address, the fixed arguments,
a list describing the format string (its address, the number of '%'
characters and the resolved variadic arguments) or "get fail", and the
FUNCATTR_FRSIZE attribute of the calling function as hex text.
"""
addr = "0x%x" % call_addr
ret_list = [func_name, addr]
# local buf size
local_buf_size = GetFunctionAttr(call_addr , FUNCATTR_FRSIZE)
if local_buf_size == BADADDR :
local_buf_size = "get fail"
else:
local_buf_size = "0x%x" % local_buf_size
# get arg
for num in xrange(0,arg_num):
ret_list.append(getArg(call_addr, num))
# Instruction that loads the format-string argument register.
arg_addr = getArgAddr(call_addr, format_function_offset_dict[func_name])
string_and_addr = getFormatString(arg_addr)
format_and_value = []
if string_and_addr == "get fail":
ret_list.append("get fail")
else:
string_addr = "0x%x" % string_and_addr[0]
format_and_value.append(string_addr)
string = string_and_addr[1]
# Counts every '%' character, so a literal "%%" is counted twice.
fmt_num = string.count("%")
format_and_value.append(fmt_num)
# mips arg reg is from a0 to a3
# NOTE(review): the original nesting of the following lines is not visible
# here; presumably the loop runs for every call, not only when fmt_num > 3.
if fmt_num > 3:
fmt_num = fmt_num - format_function_offset_dict[func_name] - 1
for num in xrange(0,fmt_num):
# Stop once the argument index would go past $a3.
if arg_num + num > 3:
break
format_and_value.append(getArg(call_addr, arg_num + num))
ret_list.append(format_and_value)
# format_string = str(getFormatString(arg_addr)[1])
# print " format String: " + format_string
# ret_list.append([string_addr])
ret_list.append(local_buf_size)
return ret_list
def mipsAudit():
    """Print the banner, then audit each configured function group in turn."""
    # the word create with figlet
    start = '''
_ _ _ _ _
_ __ ___ (_)_ __ ___ / \ _ _ __| (_) |_
| '_ ` _ \| | '_ \/ __| / _ \| | | |/ _` | | __|
| | | | | | | |_) \__ \/ ___ \ |_| | (_| | | |_
|_| |_| |_|_| .__/|___/_/ \_\__,_|\__,_|_|\__|
|_|
code by giantbranch 2018.05
'''
    print(start)
    groups = (
        ("Auditing dangerous functions ......", dangerous_functions),
        ("Auditing attention function ......", attention_function),
        ("Auditing command execution function ......", command_execution_function),
    )
    for heading, names in groups:
        print(heading)
        for func_name in names:
            audit(func_name)
    print("Finished! Enjoy the result ~")
# Code for detecting the processor architecture; may come in handy later
# info = idaapi.get_inf_structure()
# if info.is_64bit():
# bits = 64
# elif info.is_32bit():
# bits = 32
# else:
# bits = 16
# try:
# is_be = info.is_be()
# except:
# is_be = info.mf
# endian = "big" if is_be else "little"
# print 'Processor: {}, {}bit, {} endian'.format(info.procName, bits, endian)
# # Result: Processor: mipsr, 32bit, big endian
# Entry point: the audit runs as soon as this script is executed.
mipsAudit()
================================================
FILE: oldversion(py2)/prettytable.py
================================================
#!/usr/bin/env python
#
# Copyright (c) 2009-2013, Luke Maurits
# All rights reserved.
# With contributions from:
# * Chris Clark
# * Klein Stephane
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * The name of the author may not be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
__version__ = "0.7.2"
import copy
import csv
import random
import re
import sys
import textwrap
import itertools
import unicodedata
# True when running on Python 3 or newer.
py3k = sys.version_info[0] >= 3
# Compatibility shim: give the rest of the module one set of names
# (unicode, basestring, itermap, iterzip, uni_chr, HTMLParser, escape)
# regardless of the running Python major version.
if py3k:
unicode = str
basestring = str
itermap = map
iterzip = zip
uni_chr = chr
from html.parser import HTMLParser
else:
itermap = itertools.imap
iterzip = itertools.izip
uni_chr = unichr
from HTMLParser import HTMLParser
# `escape` comes from the html module on Python 3 with a minor version of 2
# or higher, and from cgi otherwise.
if py3k and sys.version_info[1] >= 2:
from html import escape
else:
from cgi import escape
# hrule styles
# (the _validate_vrules method also accepts FRAME, ALL and NONE for vrules)
FRAME = 0
ALL = 1
NONE = 2
HEADER = 3
# Table styles
DEFAULT = 10
MSWORD_FRIENDLY = 11
PLAIN_COLUMNS = 12
RANDOM = 20
_re = re.compile("\033\[[0-9;]*m")
def _get_size(text):
    """Return (width, height) of *text*: widest line in block-width units and line count."""
    rows = text.split("\n")
    widest = max(_str_block_width(row) for row in rows)
    return (widest, len(rows))
class PrettyTable(object):
def __init__(self, field_names=None, **kwargs):
    """Return a new PrettyTable instance

    Arguments:

    encoding - Unicode encoding scheme used to decode any encoded input
    field_names - list or tuple of field names
    fields - list or tuple of field names to include in displays
    start - index of first data row to include in output
    end - index of last data row to include in output PLUS ONE (list slice style)
    header - print a header showing field names (True or False)
    header_style - stylisation to apply to field names in header ("cap", "title", "upper", "lower" or None)
    border - print a border around the table (True or False)
    hrules - controls printing of horizontal rules after rows. Allowed values: FRAME, HEADER, ALL, NONE
    vrules - controls printing of vertical rules between columns. Allowed values: FRAME, ALL, NONE
    int_format - controls formatting of integer data
    float_format - controls formatting of floating point data
    padding_width - number of spaces on either side of column data (only used if left and right paddings are None)
    left_padding_width - number of spaces on left hand side of column data
    right_padding_width - number of spaces on right hand side of column data
    vertical_char - single character string used to draw vertical lines
    horizontal_char - single character string used to draw horizontal lines
    junction_char - single character string used to draw line junctions
    sortby - name of field to sort rows by
    sort_key - sorting key function, applied to data points before sorting
    valign - default valign for each row (None, "t", "m" or "b")
    reversesort - True or False to sort in descending or ascending order

    Numeric options and the rule styles fall back to their default only
    when they were not supplied, so an explicit 0 (for example
    padding_width=0, end=0 or vrules=FRAME) is honoured."""

    self.encoding = kwargs.get("encoding", "UTF-8")

    # Data
    self._field_names = []
    self._align = {}
    self._valign = {}
    self._max_width = {}
    self._rows = []
    if field_names:
        self.field_names = field_names
    else:
        self._widths = []

    # Options
    self._options = "start end fields header border sortby reversesort sort_key attributes format hrules vrules".split()
    self._options.extend("int_format float_format padding_width left_padding_width right_padding_width".split())
    self._options.extend("vertical_char horizontal_char junction_char header_style valign xhtml print_empty".split())
    for option in self._options:
        if option in kwargs:
            self._validate_option(option, kwargs[option])
        else:
            # Missing options are recorded as None so the lookups below are uniform.
            kwargs[option] = None

    def _given(name, default):
        # `kwargs[name] or default` would also replace a legitimate 0.
        value = kwargs[name]
        return default if value is None else value

    self._start = _given("start", 0)
    self._end = kwargs["end"]
    self._fields = kwargs["fields"] or None

    if kwargs["header"] in (True, False):
        self._header = kwargs["header"]
    else:
        self._header = True
    self._header_style = kwargs["header_style"] or None
    if kwargs["border"] in (True, False):
        self._border = kwargs["border"]
    else:
        self._border = True
    self._hrules = _given("hrules", FRAME)
    self._vrules = _given("vrules", ALL)

    self._sortby = kwargs["sortby"] or None
    if kwargs["reversesort"] in (True, False):
        self._reversesort = kwargs["reversesort"]
    else:
        self._reversesort = False
    self._sort_key = kwargs["sort_key"] or (lambda x: x)

    self._int_format = kwargs["int_format"] or {}
    self._float_format = kwargs["float_format"] or {}
    self._padding_width = _given("padding_width", 1)
    self._left_padding_width = kwargs["left_padding_width"]
    self._right_padding_width = kwargs["right_padding_width"]

    self._vertical_char = kwargs["vertical_char"] or self._unicode("|")
    self._horizontal_char = kwargs["horizontal_char"] or self._unicode("-")
    self._junction_char = kwargs["junction_char"] or self._unicode("+")

    if kwargs["print_empty"] in (True, False):
        self._print_empty = kwargs["print_empty"]
    else:
        self._print_empty = True
    self._format = kwargs["format"] or False
    self._xhtml = kwargs["xhtml"] or False
    self._attributes = kwargs["attributes"] or {}
def _unicode(self, value):
    """Return *value* as a unicode string, decoding byte strings with self.encoding."""
    text = value if isinstance(value, basestring) else str(value)
    if isinstance(text, unicode):
        return text
    return unicode(text, self.encoding, "strict")
def _justify(self, text, width, align):
    """Pad *text* with spaces to *width* columns.

    align "l" pads on the right, "r" pads on the left, anything else
    centres the text. When the padding cannot be split evenly, the extra
    space goes to the right for odd-width text and to the left for
    even-width text, matching str.center().
    """
    excess = width - _str_block_width(text)
    if align == "l":
        return text + excess * " "
    if align == "r":
        return excess * " " + text

    half = excess // 2
    left = right = half
    if excess % 2:
        # Uneven padding: decide which side receives the spare column.
        if _str_block_width(text) % 2:
            right = half + 1
        else:
            left = half + 1
    return left * " " + text + right * " "
def __getattr__(self, name):
if name == "rowcount":
return len(self._rows)
elif name == "colcount":
if self._field_names:
return len(self._field_names)
elif self._rows:
return len(self._rows[0])
else:
return 0
else:
raise AttributeError(name)
def __getitem__(self, index):
    """Return a new PrettyTable holding the row(s) selected by *index*.

    The new table gets the field names, all options and the alignment of
    this one. *index* must be an int or a slice; anything else raises
    Exception.
    """
    subset = PrettyTable()
    subset.field_names = self.field_names
    for option in self._options:
        private_name = "_" + option
        setattr(subset, private_name, getattr(self, private_name))
    subset._align = self._align

    if isinstance(index, slice):
        selected = self._rows[index]
    elif isinstance(index, int):
        selected = [self._rows[index]]
    else:
        raise Exception("Index %s is invalid, must be an integer or slice" % str(index))
    for row in selected:
        subset.add_row(row)
    return subset
def __str__(self):
    """Return the rendered table: text on Python 3, encoded bytes on Python 2."""
    rendered = self.__unicode__()
    if py3k:
        return rendered
    return rendered.encode(self.encoding)

def __unicode__(self):
    """Return the rendered table as produced by get_string()."""
    return self.get_string()
##############################
# ATTRIBUTE VALIDATORS #
##############################
# The method _validate_option is all that should be used elsewhere in the code base to validate options.
# It will call the appropriate validation method for that option. The individual validation methods should
# never need to be called directly (although nothing bad will happen if they *are*).
# Validation happens in TWO places.
# Firstly, in the property setters defined in the ATTRIBUTE MANAGEMENT section.
# Secondly, in the _get_options method, where keyword arguments are mixed with persistent settings
def _validate_option(self, option, val):
if option in ("field_names"):
self._validate_field_names(val)
elif option in ("start", "end", "max_width", "padding_width", "left_padding_width", "right_padding_width", "format"):
self._validate_nonnegative_int(option, val)
elif option in ("sortby"):
self._validate_field_name(option, val)
elif option in ("sort_key"):
self._validate_function(option, val)
elif option in ("hrules"):
self._validate_hrules(option, val)
elif option in ("vrules"):
self._validate_vrules(option, val)
elif option in ("fields"):
self._validate_all_field_names(option, val)
elif option in ("header", "border", "reversesort", "xhtml", "print_empty"):
self._validate_true_or_false(option, val)
elif option in ("header_style"):
self._validate_header_style(val)
elif option in ("int_format"):
self._validate_int_format(option, val)
elif option in ("float_format"):
self._validate_float_format(option, val)
elif option in ("vertical_char", "horizontal_char", "junction_char"):
self._validate_single_char(option, val)
elif option in ("attributes"):
self._validate_attributes(option, val)
else:
raise Exception("Unrecognised option: %s!" % option)
def _validate_field_names(self, val):
# Check for appropriate length
if self._field_names:
try:
assert len(val) == len(self._field_names)
except AssertionError:
raise Exception("Field name list has incorrect number of values, (actual) %d!=%d (expected)" % (len(val), len(self._field_names)))
if self._rows:
try:
assert len(val) == len(self._rows[0])
except AssertionError:
raise Exception("Field name list has incorrect number of values, (actual) %d!=%d (expected)" % (len(val), len(self._rows[0])))
# Check for uniqueness
try:
assert len(val) == len(set(val))
except AssertionError:
raise Exception("Field names must be unique!")
def _validate_header_style(self, val):
try:
assert val in ("cap", "title", "upper", "lower", None)
except AssertionError:
raise Exception("Invalid header style, use cap, title, upper, lower or None!")
def _validate_align(self, val):
try:
assert val in ["l","c","r"]
except AssertionError:
raise Exception("Alignment %s is invalid, use l, c or r!" % val)
def _validate_valign(self, val):
try:
assert val in ["t","m","b",None]
except AssertionError:
raise Exception("Alignment %s is invalid, use t, m, b or None!" % val)
def _validate_nonnegative_int(self, name, val):
try:
assert int(val) >= 0
except AssertionError:
raise Exception("Invalid value for %s: %s!" % (name, self._unicode(val)))
def _validate_true_or_false(self, name, val):
try:
assert val in (True, False)
except AssertionError:
raise Exception("Invalid value for %s! Must be True or False." % name)
def _validate_int_format(self, name, val):
    """Accept "" or a string made of digits only; raise Exception otherwise.

    Uses an explicit check: ``assert`` is removed under ``python -O``.
    """
    if val == "":
        return
    if not (isinstance(val, (str, unicode)) and val.isdigit()):
        raise Exception("Invalid value for %s! Must be an integer format string." % name)
def _validate_float_format(self, name, val):
    """Accept "" or a string of the form "<digits>.<digits>"; raise Exception otherwise.

    Either side of the single "." may be empty. Uses explicit checks:
    ``assert`` is removed under ``python -O``.
    """
    if val == "":
        return
    valid = isinstance(val, (str, unicode)) and "." in val
    if valid:
        bits = val.split(".")
        valid = (
            len(bits) <= 2
            and (bits[0] == "" or bits[0].isdigit())
            and (bits[1] == "" or bits[1].isdigit())
        )
    if not valid:
        raise Exception("Invalid value for %s! Must be a float format string." % name)
def _validate_function(self, name, val):
try:
assert hasattr(val, "__call__")
except AssertionError:
raise Exception("Invalid value for %s! Must be a function." % name)
def _validate_hrules(self, name, val):
    """Raise Exception unless *val* is one of ALL, FRAME, HEADER or NONE.

    Uses an explicit check: ``assert`` is removed under ``python -O``.
    """
    if val not in (ALL, FRAME, HEADER, NONE):
        raise Exception("Invalid value for %s! Must be ALL, FRAME, HEADER or NONE." % name)
def _validate_vrules(self, name, val):
    """Raise Exception unless *val* is one of ALL, FRAME or NONE.

    Uses an explicit check: ``assert`` is removed under ``python -O``.
    """
    if val not in (ALL, FRAME, NONE):
        raise Exception("Invalid value for %s! Must be ALL, FRAME, or NONE." % name)
def _validate_field_name(self, name, val):
try:
assert (val in self._field_names) or (val is None)
except AssertionError:
raise Exception("Invalid field name: %s!" % val)
def _validate_all_field_names(self, name, val):
try:
for x in val:
self._validate_field_name(name, x)
except AssertionError:
raise Exception("fields must be a sequence of field names!")
def _validate_single_char(self, name, val):
    """Raise Exception unless *val* has a block width of exactly 1.

    Uses an explicit check: ``assert`` is removed under ``python -O``.
    """
    if _str_block_width(val) != 1:
        raise Exception("Invalid value for %s! Must be a string of length 1." % name)
def _validate_attributes(self, name, val):
try:
assert isinstance(val, dict)
except AssertionError:
raise Exception("attributes must be a dictionary of name/value pairs!")
##############################
# ATTRIBUTE MANAGEMENT #
##############################
def _get_field_names(self):
    """The names of the fields

    Arguments:

    fields - list or tuple of field names"""
    return self._field_names
def _set_field_names(self, val):
    """Replace the field names and carry alignment settings across.

    Each name is converted with self._unicode and the resulting list is
    validated under the "field_names" option. When the table already had
    names and alignment settings, each new name inherits the setting of the
    old name in the same position; settings for old names that are no
    longer present are dropped, and new names without a predecessor get the
    default ("c" horizontally, "t" vertically). Otherwise every field is
    set to the default.

    Arguments:

    val - iterable of new field names"""
    val = [self._unicode(x) for x in val]
    self._validate_option("field_names", val)
    # Always bind old_names; leaving it unset when there were no previous
    # names raised UnboundLocalError if alignment data was still present.
    old_names = self._field_names[:] if self._field_names else []
    self._field_names = val
    for settings, default in ((self._align, "c"), (self._valign, "t")):
        if settings and old_names:
            # Read every inherited value before writing any, so a rename
            # that reuses an old name is not affected by assignment order.
            carried = {}
            for old_name, new_name in zip(old_names, val):
                if old_name in settings:
                    carried[new_name] = settings[old_name]
            # Drop entries belonging to names that no longer exist.
            for old_name in old_names:
                if old_name not in val:
                    settings.pop(old_name, None)
            settings.update(carried)
            # Fields beyond the old list still need a setting, otherwise
            # later lookups by field name fail.
            for field in val:
                settings.setdefault(field, default)
        else:
            for field in val:
                settings[field] = default
field_names = property(_get_field_names, _set_field_names)
def _get_align(self):
    """Mapping of field name to horizontal alignment setting."""
    return self._align
def _set_align(self, val):
    """Apply one horizontal alignment to every current field.

    Arguments:

    val - alignment value, checked with _validate_align"""
    self._validate_align(val)
    self._align.update(dict.fromkeys(self._field_names, val))
align = property(_get_align, _set_align)
def _get_valign(self):
    """Mapping of field name to vertical alignment setting."""
    return self._valign
def _set_valign(self, val):
    """Apply one vertical alignment to every current field.

    Arguments:

    val - alignment value, checked with _validate_valign"""
    self._validate_valign(val)
    self._valign.update(dict.fromkeys(self._field_names, val))
valign = property(_get_valign, _set_valign)
def _get_max_width(self):
    """Mapping of field name to maximum column width."""
    return self._max_width
def _set_max_width(self, val):
    """Apply one maximum width to every current field.

    Arguments:

    val - width, validated under the "max_width" option"""
    self._validate_option("max_width", val)
    self._max_width.update(dict.fromkeys(self._field_names, val))
max_width = property(_get_max_width, _set_max_width)
def _get_fields(self):
    """Field names to include when the table is displayed

    Arguments:

    fields - list or tuple of field names to include in displays"""
    return self._fields
def _set_fields(self, val):
    """Validate val under the "fields" option, then store it."""
    option = "fields"
    self._validate_option(option, val)
    self._fields = val
fields = property(_get_fields, _set_fields)
def _get_start(self):
    """Index of the first data row to print

    Arguments:

    start - index of first data row to include in output"""
    return self._start
def _set_start(self, val):
    """Validate val under the "start" option, then store it."""
    option = "start"
    self._validate_option(option, val)
    self._start = val
start = property(_get_start, _set_start)
def _get_end(self):
    """Index one past the last data row to print

    Arguments:

    end - index of last data row to include in output PLUS ONE (list slice style)"""
    return self._end
def _set_end(self, val):
    """Validate val under the "end" option, then store it."""
    option = "end"
    self._validate_option(option, val)
    self._end = val
end = property(_get_end, _set_end)
def _get_sortby(self):
    """Field whose values determine row order

    Arguments:

    sortby - field name to sort by"""
    return self._sortby
def _set_sortby(self, val):
    """Validate val under the "sortby" option, then store it."""
    option = "sortby"
    self._validate_option(option, val)
    self._sortby = val
sortby = property(_get_sortby, _set_sortby)
def _get_reversesort(self):
    """Direction of sorting (ascending vs descending)

    Arguments:

    reversesort - set to True to sort by descending order, or False to sort by ascending order"""
    return self._reversesort
def _set_reversesort(self, val):
    """Validate val under the "reversesort" option, then store it."""
    option = "reversesort"
    self._validate_option(option, val)
    self._reversesort = val
reversesort = property(_get_reversesort, _set_reversesort)
def _get_sort_key(self):
    """Key function applied to rows before sorting

    Arguments:

    sort_key - a function which takes one argument and returns something to be sorted"""
    return self._sort_key
def _set_sort_key(self, val):
    """Validate val under the "sort_key" option, then store it."""
    option = "sort_key"
    self._validate_option(option, val)
    self._sort_key = val
sort_key = property(_get_sort_key, _set_sort_key)
def _get_header(self):
    """Whether a header row with the field names is printed

    Arguments:

    header - print a header showing field names (True or False)"""
    return self._header
def _set_header(self, val):
    """Validate val under the "header" option, then store it."""
    option = "header"
    self._validate_option(option, val)
    self._header = val
header = property(_get_header, _set_header)
def _get_header_style(self):
    """Stylisation applied to field names in the header

    Arguments:

    header_style - stylisation to apply to field names in header ("cap", "title", "upper", "lower" or None)"""
    return self._header_style
def _set_header_style(self, val):
    """Check val with _validate_header_style, then store it."""
    check = self._validate_header_style
    check(val)
    self._header_style = val
header_style = property(_get_header_style, _set_header_style)
def _get_border(self):
    """Whether a border is printed around the table

    Arguments:

    border - print a border around the table (True or False)"""
    return self._border
def _set_border(self, val):
    """Validate val under the "border" option, then store it."""
    option = "border"
    self._validate_option(option, val)
    self._border = val
border = property(_get_border, _set_border)
def _get_hrules(self):
    """Style of horizontal rules printed after rows

    Arguments:

    hrules - horizontal rules style. Allowed values: FRAME, ALL, HEADER, NONE"""
    return self._hrules
def _set_hrules(self, val):
    """Validate val under the "hrules" option, then store it."""
    option = "hrules"
    self._validate_option(option, val)
    self._hrules = val
hrules = property(_get_hrules, _set_hrules)
def _get_vrules(self):
    """Style of vertical rules printed between columns

    Arguments:

    vrules - vertical rules style. Allowed values: FRAME, ALL, NONE"""
    return self._vrules
def _set_vrules(self, val):
    """Validate val under the "vrules" option, then store it."""
    option = "vrules"
    self._validate_option(option, val)
    self._vrules = val
vrules = property(_get_vrules, _set_vrules)
def _get_int_format(self):
    """Mapping of field name to integer format string

    Arguments:

    int_format - integer format string"""
    return self._int_format
def _set_int_format(self, val):
    """Use val as the integer format for every current field."""
    # NOTE(review): val is stored without validation here - confirm whether
    # self._validate_option("int_format", val) should be called first.
    self._int_format.update(dict.fromkeys(self._field_names, val))
int_format = property(_get_int_format, _set_int_format)
def _get_float_format(self):
    """Mapping of field name to floating point format string

    Arguments:

    float_format - floating point format string"""
    return self._float_format
def _set_float_format(self, val):
    """Use val as the float format for every current field."""
    # NOTE(review): val is stored without validation here - confirm whether
    # self._validate_option("float_format", val) should be called first.
    self._float_format.update(dict.fromkeys(self._field_names, val))
float_format = property(_get_float_format, _set_float_format)
def _get_padding_width(self):
    """Number of spaces between a column's edge and its content

    Arguments:

    padding_width - number of spaces, must be a positive integer"""
    return self._padding_width
def _set_padding_width(self, val):
    """Validate val under the "padding_width" option, then store it."""
    option = "padding_width"
    self._validate_option(option, val)
    self._padding_width = val
padding_width = property(_get_padding_width, _set_padding_width)
def _get_left_padding_width(self):
    """Number of spaces between a column's left edge and its content

    Arguments:

    left_padding - number of spaces, must be a positive integer"""
    return self._left_padding_width
def _set_left_padding_width(self, val):
    """Validate val under the "left_padding_width" option, then store it."""
    option = "left_padding_width"
    self._validate_option(option, val)
    self._left_padding_width = val
left_padding_width = property(_get_left_padding_width, _set_left_padding_width)
def _get_right_padding_width(self):
    """Number of spaces between a column's right edge and its content

    Arguments:

    right_padding - number of spaces, must be a positive integer"""
    return self._right_padding_width
def _set_right_padding_width(self, val):
    """Validate val under the "right_padding_width" option, then store it."""
    option = "right_padding_width"
    self._validate_option(option, val)
    self._right_padding_width = val
right_padding_width = property(_get_right_padding_width, _set_right_padding_width)
def _get_vertical_char(self):
    """The character used to draw vertical lines in table borders

    Arguments:

    vertical_char - single character string used to draw vertical lines"""
    return self._vertical_char
def _set_vertical_char(self, val):
    """Convert val with self._unicode, validate it, then store it."""
    converted = self._unicode(val)
    self._validate_option("vertical_char", converted)
    self._vertical_char = converted
vertical_char = property(_get_vertical_char, _set_vertical_char)
def _get_horizontal_char(self):
    """The character used to draw horizontal lines in table borders

    Arguments:

    horizontal_char - single character string used to draw horizontal lines"""
    return self._horizontal_char
def _set_horizontal_char(self, val):
    """Convert val with self._unicode, validate it, then store it."""
    converted = self._unicode(val)
    self._validate_option("horizontal_char", converted)
    self._horizontal_char = converted
horizontal_char = property(_get_horizontal_char, _set_horizontal_char)
def _get_junction_char(self):
    """The character used when printing table borders to draw line junctions

    Arguments:

    junction_char - single character string used to draw line junctions"""
    return self._junction_char
def _set_junction_char(self, val):
    """Convert val with self._unicode, validate it, then store it.

    Arguments:

    val - single character string used to draw line junctions"""
    val = self._unicode(val)
    # Validate under this option's own name; using "vertical_char" here made
    # validation errors name the wrong option.
    self._validate_option("junction_char", val)
    self._junction_char = val
junction_char = property(_get_junction_char, _set_junction_char)
def _get_format(self):
    """Whether HTML tables are formatted to match styling options

    Arguments:

    format - True or False"""
    return self._format
def _set_format(self, val):
    """Validate val under the "format" option, then store it."""
    option = "format"
    self._validate_option(option, val)
    self._format = val
format = property(_get_format, _set_format)
def _get_print_empty(self):
    """Whether empty tables produce a header and frame or just an empty string

    Arguments:

    print_empty - True or False"""
    return self._print_empty
def _set_print_empty(self, val):
    """Validate val under the "print_empty" option, then store it."""
    option = "print_empty"
    self._validate_option(option, val)
    self._print_empty = val
print_empty = property(_get_print_empty, _set_print_empty)
def _get_attributes(self):
    """HTML attribute name/value pairs to include in the table tag when printing HTML

    Arguments:

    attributes - dictionary of attributes"""
    return self._attributes
def _set_attributes(self, val):
    """Validate val under the "attributes" option, then store it."""
    option = "attributes"
    self._validate_option(option, val)
    self._attributes = val
attributes = property(_get_attributes, _set_attributes)
##############################
# OPTION MIXER #
##############################
def _get_options(self, kwargs):
    """Build the effective option dictionary for one call.

    For every name in self._options, a value supplied in kwargs is validated
    and used; otherwise the value stored on the instance under "_" + name is
    used.

    Arguments:

    kwargs - dictionary of per-call option overrides"""
    resolved = {}
    for key in self._options:
        if key not in kwargs:
            resolved[key] = getattr(self, "_" + key)
            continue
        self._validate_option(key, kwargs[key])
        resolved[key] = kwargs[key]
    return resolved
##############################
# PRESET STYLE LOGIC #
##############################
def set_style(self, style):
    """Apply one of the pre-set table styles.

    Arguments:

    style - DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS or RANDOM

    Raises Exception for any other value."""
    presets = (
        (DEFAULT, self._set_default_style),
        (MSWORD_FRIENDLY, self._set_msword_style),
        (PLAIN_COLUMNS, self._set_columns_style),
        (RANDOM, self._set_random_style),
    )
    for constant, apply_preset in presets:
        if style == constant:
            apply_preset()
            return
    raise Exception("Invalid pre-set style!")
def _set_default_style(self):
    """Apply the default style: header, border, frame rules, "|", "-" and "+"."""
    self.header = True
    self.border = True
    # The rule settings are written straight to the private attributes.
    self._hrules = FRAME
    self._vrules = ALL
    for attr in ("padding_width", "left_padding_width", "right_padding_width"):
        setattr(self, attr, 1)
    self.vertical_char = "|"
    self.horizontal_char = "-"
    self.junction_char = "+"
def _set_msword_style(self):
    """Apply the MS Word friendly style: bordered, no horizontal rules."""
    self.header = True
    self.border = True
    # Written straight to the private attribute.
    self._hrules = NONE
    for attr in ("padding_width", "left_padding_width", "right_padding_width"):
        setattr(self, attr, 1)
    self.vertical_char = "|"
def _set_columns_style(self):
    """Apply the plain columns style: header, no border, wide right padding."""
    settings = (
        ("header", True),
        ("border", False),
        ("padding_width", 1),
        ("left_padding_width", 0),
        ("right_padding_width", 8),
    )
    for attr, value in settings:
        setattr(self, attr, value)
def _set_random_style(self):
    """Apply randomly chosen style settings.

    Header, border, rule styles, left/right padding (0 to 5) and the three
    border characters are each drawn with the random module."""
    # Just for fun!
    # The backslash is escaped explicitly: "\{" is an invalid escape sequence
    # that newer Python versions warn about. The characters are unchanged.
    glyphs = "~!@#$%^&*()_+|-=\\{}[];':\",./;<>?"
    self.header = random.choice((True, False))
    self.border = random.choice((True, False))
    self._hrules = random.choice((ALL, FRAME, HEADER, NONE))
    self._vrules = random.choice((ALL, FRAME, NONE))
    self.left_padding_width = random.randint(0, 5)
    self.right_padding_width = random.randint(0, 5)
    self.vertical_char = random.choice(glyphs)
    self.horizontal_char = random.choice(glyphs)
    self.junction_char = random.choice(glyphs)
##############################
# DATA INPUT METHODS #
##############################
def add_row(self, row):
    """Append one row of data to the table

    If the table has no field names yet, names of the form "Field N" are
    generated, one per value in the row.

    Arguments:

    row - row of data, should be a list with as many elements as the table
    has fields

    Raises Exception when field names exist and the row length differs."""
    if not self._field_names:
        self.field_names = ["Field %d" % (n + 1) for n in range(len(row))]
    elif len(row) != len(self._field_names):
        raise Exception("Row has incorrect number of values, (actual) %d!=%d (expected)" %(len(row),len(self._field_names)))
    self._rows.append(list(row))
def del_row(self, row_index):
    """Delete a row from the table

    Arguments:

    row_index - The index of the row you want to delete. Indexing starts at 0.

    Raises Exception when row_index is past the last row."""
    last_index = len(self._rows) - 1
    if row_index > last_index:
        raise Exception("Cant delete row at index %d, table only has %d rows!" % (row_index, len(self._rows)))
    del self._rows[row_index]
def add_column(self, fieldname, column, align="c", valign="t"):
    """Append a column of data to the table.

    Arguments:

    fieldname - name of the field to contain the new column of data
    column - column of data, should be a list with as many elements as the
    table has rows
    align - desired alignment for this column - "l" for left, "c" for centre and "r" for right
    valign - desired vertical alignment for new columns - "t" for top, "m" for middle and "b" for bottom

    Raises Exception when the table already has rows and the column length
    differs from the row count."""
    if len(self._rows) not in (0, len(column)):
        raise Exception("Column length %d does not match number of rows %d!" % (len(column), len(self._rows)))
    self._validate_align(align)
    self._validate_valign(valign)
    self._field_names.append(fieldname)
    self._align[fieldname] = align
    self._valign[fieldname] = valign
    for position in range(len(column)):
        # An empty table grows one row per column entry.
        if len(self._rows) < position + 1:
            self._rows.append([])
        self._rows[position].append(column[position])
def clear_rows(self):
    """Remove every data row; the field names are left untouched"""
    self._rows = list()
def clear(self):
    """Remove all rows, field names and computed widths; styling options are kept"""
    self._rows, self._field_names, self._widths = [], [], []
##############################
# MISC PUBLIC METHODS #
##############################
def copy(self):
    """Return a deep copy of this table."""
    duplicate = copy.deepcopy(self)
    return duplicate
##############################
# MISC PRIVATE METHODS #
##############################
def _format_value(self, field, value):
    """Convert one cell value to text.

    An int whose field has an integer format, or a float whose field has a
    float format, is rendered with a "%" template built from that format;
    the result is always passed through self._unicode.

    Arguments:

    field - name of the field the value belongs to
    value - the cell value"""
    if isinstance(value, int) and field in self._int_format:
        template = "%%%sd" % self._int_format[field]
        value = self._unicode(template % value)
    elif isinstance(value, float) and field in self._float_format:
        template = "%%%sf" % self._float_format[field]
        value = self._unicode(template % value)
    return self._unicode(value)
def _compute_widths(self, rows, options):
    """Compute the width of every column and store the list in self._widths.

    Widths start from the field name widths when the header is printed and
    from zero otherwise, then grow to fit each cell; a field with an entry
    in self.max_width is capped at that value.

    Arguments:

    rows - rows of already formatted cell text
    options - dictionary of option settings"""
    if options["header"]:
        widths = [_get_size(field)[0] for field in self._field_names]
    else:
        widths = [0] * len(self.field_names)
    for row in rows:
        for index, value in enumerate(row):
            fieldname = self.field_names[index]
            cell_width = _get_size(value)[0]
            if fieldname in self.max_width:
                cell_width = min(cell_width, self.max_width[fieldname])
            widths[index] = max(widths[index], cell_width)
    self._widths = widths
def _get_padding_widths(self, options):
    """Return the (left, right) padding widths to use.

    A side-specific width that is not None wins; otherwise the general
    "padding_width" value is used for that side.

    Arguments:

    options - dictionary of option settings"""
    def pick(side):
        width = options[side]
        return options["padding_width"] if width is None else width
    return pick("left_padding_width"), pick("right_padding_width")
def _get_rows(self, options):
    """Return copies of the data rows selected for printing.

    The rows in the start/end slice are deep-copied; when a sort field is
    set they are sorted on that field's value using the configured key and
    direction.

    Arguments:

    options - dictionary of option settings."""
    selected = copy.deepcopy(self._rows[options["start"]:options["end"]])
    if not options["sortby"]:
        return selected
    column = self._field_names.index(options["sortby"])
    # Prefix each row with its sort value, sort, then strip the prefix.
    decorated = [[row[column]] + row for row in selected]
    decorated.sort(reverse=options["reversesort"], key=options["sort_key"])
    return [entry[1:] for entry in decorated]
def _format_row(self, row, options):
    """Return the row's values formatted with _format_value, paired by position with the field names."""
    formatted = []
    for field, value in zip(self._field_names, row):
        formatted.append(self._format_value(field, value))
    return formatted
def _format_rows(self, rows, options):
    """Return a list with every row passed through _format_row."""
    formatted = []
    for row in rows:
        formatted.append(self._format_row(row, options))
    return formatted
##############################
# PLAIN TEXT STRING METHODS #
##############################
def get_string(self, **kwargs):
    """Return the table, in its current state, as plain text.

    Keyword arguments override the stored options for this call only.

    Arguments:

    start - index of first data row to include in output
    end - index of last data row to include in output PLUS ONE (list slice style)
    fields - names of fields (columns) to include
    header - print a header showing field names (True or False)
    border - print a border around the table (True or False)
    hrules - controls printing of horizontal rules after rows.  Allowed values: ALL, FRAME, HEADER, NONE
    vrules - controls printing of vertical rules between columns.  Allowed values: FRAME, ALL, NONE
    int_format - controls formatting of integer data
    float_format - controls formatting of floating point data
    padding_width - number of spaces on either side of column data (only used if left and right paddings are None)
    left_padding_width - number of spaces on left hand side of column data
    right_padding_width - number of spaces on right hand side of column data
    vertical_char - single character string used to draw vertical lines
    horizontal_char - single character string used to draw horizontal lines
    junction_char - single character string used to draw line junctions
    sortby - name of field to sort rows by
    sort_key - sorting key function, applied to data points before sorting
    reversesort - True or False to sort in descending or ascending order
    print empty - if True, stringify just the header for an empty table, if False return an empty string """
    options = self._get_options(kwargs)
    # An empty table yields "" unless both print_empty and border are set.
    if self.rowcount == 0 and not (options["print_empty"] and options["border"]):
        return ""
    # Select (slice/sort) the rows, turn every cell into text, size columns.
    formatted_rows = self._format_rows(self._get_rows(options), options)
    self._compute_widths(formatted_rows, options)
    # The horizontal rule is cached for the header and row writers.
    self._hrule = self._stringify_hrule(options)
    bordered = options["border"]
    lines = []
    if options["header"]:
        lines.append(self._stringify_header(options))
    elif bordered and options["hrules"] in (ALL, FRAME):
        lines.append(self._hrule)
    lines.extend([self._stringify_row(row, options) for row in formatted_rows])
    # Bottom edge of the frame.
    if bordered and options["hrules"] == FRAME:
        lines.append(self._hrule)
    return self._unicode("\n").join(lines)
def _stringify_hrule(self, options):
    """Build the horizontal rule line for the current column widths.

    Returns "" when the border is off.

    Arguments:

    options - dictionary of option settings"""
    if not options["border"]:
        return ""
    lpad, rpad = self._get_padding_widths(options)
    junction = options["junction_char"]
    dash = options["horizontal_char"]
    vrules = options['vrules']
    pieces = [junction if vrules in (ALL, FRAME) else dash]
    # For tables with no data or fieldnames
    if not self._field_names:
        pieces.append(junction)
        return "".join(pieces)
    separator = junction if vrules == ALL else dash
    for field, width in zip(self._field_names, self._widths):
        if options["fields"] and field not in options["fields"]:
            continue
        pieces.append((width+lpad+rpad)*dash)
        pieces.append(separator)
    # With FRAME rules the last separator becomes the closing junction.
    if vrules == FRAME:
        pieces.pop()
        pieces.append(junction)
    return "".join(pieces)
def _stringify_header(self, options):
    """Build the header text: optional top rule, field names, optional rule below.

    Arguments:

    options - dictionary of option settings"""
    pieces = []
    lpad, rpad = self._get_padding_widths(options)
    bordered = options["border"]
    vrules = options["vrules"]
    edge = options["vertical_char"] if vrules in (ALL, FRAME) else " "
    if bordered:
        if options["hrules"] in (ALL, FRAME):
            pieces.append(self._hrule)
            pieces.append("\n")
        pieces.append(edge)
    # For tables with no data or field names
    if not self._field_names:
        pieces.append(edge)
    for field, width, in zip(self._field_names, self._widths):
        if options["fields"] and field not in options["fields"]:
            continue
        style = self._header_style
        if style == "cap":
            fieldname = field.capitalize()
        elif style == "title":
            fieldname = field.title()
        elif style == "upper":
            fieldname = field.upper()
        elif style == "lower":
            fieldname = field.lower()
        else:
            fieldname = field
        pieces.append(" " * lpad + self._justify(fieldname, width, self._align[field]) + " " * rpad)
        if bordered:
            pieces.append(options["vertical_char"] if vrules == ALL else " ")
    # With FRAME rules the trailing space after the last field is replaced
    # by the closing vertical character.
    if bordered and vrules == FRAME:
        pieces.pop()
        pieces.append(options["vertical_char"])
    if bordered and options["hrules"] != NONE:
        pieces.append("\n")
        pieces.append(self._hrule)
    return "".join(pieces)
def _stringify_row(self, row, options):
    """Build the text for one data row, which may span several lines.

    Cell lines wider than their column are re-wrapped with textwrap.fill;
    the wrapped text is written back into row. Shorter cells are padded
    with blank lines according to the field's vertical alignment.

    Arguments:

    row - list of already formatted cell strings (modified in place)
    options - dictionary of option settings"""
    # Enforce max widths
    for index, (value, width) in enumerate(zip(row, self._widths)):
        wrapped = []
        for line in value.split("\n"):
            if _str_block_width(line) > width:
                line = textwrap.fill(line, width)
            wrapped.append(line)
        row[index] = "\n".join(wrapped)
    row_height = 0
    for c in row:
        h = _get_size(c)[1]
        if h > row_height:
            row_height = h
    lpad, rpad = self._get_padding_widths(options)
    bordered = options["border"]
    vrules = options["vrules"]
    # Take the separator from the per-call options everywhere; reading
    # self.vertical_char for some separators ignored a vertical_char
    # override passed for this call.
    vertical = options["vertical_char"]
    bits = []
    for y in range(0, row_height):
        bits.append([])
        if bordered:
            bits[y].append(vertical if vrules in (ALL, FRAME) else " ")
    for field, value, width, in zip(self._field_names, row, self._widths):
        # Fields excluded from display contribute nothing.
        if options["fields"] and field not in options["fields"]:
            continue
        valign = self._valign[field]
        lines = value.split("\n")
        dHeight = row_height - len(lines)
        if dHeight:
            if valign == "m":
                lines = [""] * int(dHeight / 2) + lines + [""] * (dHeight - int(dHeight / 2))
            elif valign == "b":
                lines = [""] * dHeight + lines
            else:
                lines = lines + [""] * dHeight
        for y, l in enumerate(lines):
            bits[y].append(" " * lpad + self._justify(l, width, self._align[field]) + " " * rpad)
            if bordered:
                bits[y].append(vertical if vrules == ALL else " ")
    # If vrules is FRAME, then we just appended a space at the end
    # of the last field, when we really want a vertical character
    if bordered and vrules == FRAME:
        for y in range(0, row_height):
            bits[y].pop()
            bits[y].append(vertical)
    if bordered and options["hrules"] == ALL:
        bits[row_height-1].append("\n")
        bits[row_height-1].append(self._hrule)
    return "\n".join(["".join(parts) for parts in bits])
##############################
# HTML STRING METHODS #
##############################
def get_html_string(self, **kwargs):
    """Return the table, in its current state, as an HTML string.

    Keyword arguments override the stored options for this call only. The
    "format" option selects the styled renderer over the simple one.

    Arguments:

    start - index of first data row to include in output
    end - index of last data row to include in output PLUS ONE (list slice style)
    fields - names of fields (columns) to include
    header - print a header showing field names (True or False)
    border - print a border around the table (True or False)
    hrules - controls printing of horizontal rules after rows.  Allowed values: ALL, FRAME, HEADER, NONE
    vrules - controls printing of vertical rules between columns.  Allowed values: FRAME, ALL, NONE
    int_format - controls formatting of integer data
    float_format - controls formatting of floating point data
    padding_width - number of spaces on either side of column data (only used if left and right paddings are None)
    left_padding_width - number of spaces on left hand side of column data
    right_padding_width - number of spaces on right hand side of column data
    sortby - name of field to sort rows by
    sort_key - sorting key function, applied to data points before sorting
    attributes - dictionary of name/value pairs to include as HTML attributes in the table tag
    xhtml - print XHTML-style line break tags if True, plain HTML line break tags if false"""
    options = self._get_options(kwargs)
    if options["format"]:
        renderer = self._get_formatted_html_string
    else:
        renderer = self._get_simple_html_string
    return renderer(options)
def _get_simple_html_string(self, options):
    """Render the table as unstyled HTML.

    Field names and cell text are passed through escape(), and newlines
    inside them become line break tags.

    NOTE(review): the HTML markup inside this method's string literals was
    missing, leaving unterminated strings; the tags and the attribute loop
    below are reconstructed - confirm against the intended output.

    Arguments:

    options - dictionary of option settings"""
    lines = []
    if options["xhtml"]:
        linebreak = "<br/>"
    else:
        linebreak = "<br>"
    open_tag = []
    open_tag.append("<table")
    if options["attributes"]:
        for attr_name in options["attributes"]:
            open_tag.append(" %s=\"%s\"" % (attr_name, options["attributes"][attr_name]))
    open_tag.append(">")
    lines.append("".join(open_tag))
    # Headers
    if options["header"]:
        lines.append("    <tr>")
        for field in self._field_names:
            if options["fields"] and field not in options["fields"]:
                continue
            lines.append("        <th>%s</th>" % escape(field).replace("\n", linebreak))
        lines.append("    </tr>")
    # Data
    rows = self._get_rows(options)
    formatted_rows = self._format_rows(rows, options)
    for row in formatted_rows:
        lines.append("    <tr>")
        for field, datum in zip(self._field_names, row):
            if options["fields"] and field not in options["fields"]:
                continue
            lines.append("        <td>%s</td>" % escape(datum).replace("\n", linebreak))
        lines.append("    </tr>")
    lines.append("</table>")
    return self._unicode("\n").join(lines)
def _get_formatted_html_string(self, options):
    """Render the table as HTML styled to match the table options.

    Padding widths become inline "em" padding, and each field's horizontal
    and vertical alignment become inline text-align / vertical-align
    styles. Field names and cell text are passed through escape(), and
    newlines inside them become line break tags.

    NOTE(review): the HTML markup inside this method's string literals was
    missing, leaving unterminated strings and format strings with fewer
    placeholders than arguments; the tags, the frame/rules attributes and
    the attribute loop below are reconstructed - confirm against the
    intended output.

    Arguments:

    options - dictionary of option settings"""
    lines = []
    lpad, rpad = self._get_padding_widths(options)
    if options["xhtml"]:
        linebreak = "<br/>"
    else:
        linebreak = "<br>"
    open_tag = []
    open_tag.append("<table")
    if options["border"]:
        if options["hrules"] == ALL and options["vrules"] == ALL:
            open_tag.append(" frame=\"box\" rules=\"all\"")
        elif options["hrules"] == FRAME and options["vrules"] == FRAME:
            open_tag.append(" frame=\"box\"")
        elif options["hrules"] == FRAME and options["vrules"] == ALL:
            open_tag.append(" frame=\"box\" rules=\"cols\"")
        elif options["hrules"] == FRAME:
            open_tag.append(" frame=\"hsides\"")
        elif options["hrules"] == ALL:
            open_tag.append(" frame=\"hsides\" rules=\"rows\"")
        elif options["vrules"] == FRAME:
            open_tag.append(" frame=\"vsides\"")
        elif options["vrules"] == ALL:
            open_tag.append(" frame=\"vsides\" rules=\"cols\"")
    if options["attributes"]:
        for attr_name in options["attributes"]:
            open_tag.append(" %s=\"%s\"" % (attr_name, options["attributes"][attr_name]))
    open_tag.append(">")
    lines.append("".join(open_tag))
    # Headers
    if options["header"]:
        lines.append("    <tr>")
        for field in self._field_names:
            if options["fields"] and field not in options["fields"]:
                continue
            lines.append("        <th style=\"padding-left: %dem; padding-right: %dem; text-align: center\">%s</th>" % (lpad, rpad, escape(field).replace("\n", linebreak)))
        lines.append("    </tr>")
    # Data
    rows = self._get_rows(options)
    formatted_rows = self._format_rows(rows, options)
    aligns = []
    valigns = []
    for field in self._field_names:
        aligns.append({ "l" : "left", "r" : "right", "c" : "center" }[self._align[field]])
        valigns.append({"t" : "top", "m" : "middle", "b" : "bottom"}[self._valign[field]])
    for row in formatted_rows:
        lines.append("    <tr>")
        for field, datum, align, valign in zip(self._field_names, row, aligns, valigns):
            if options["fields"] and field not in options["fields"]:
                continue
            lines.append("        <td style=\"padding-left: %dem; padding-right: %dem; text-align: %s; vertical-align: %s\">%s</td>" % (lpad, rpad, align, valign, escape(datum).replace("\n", linebreak)))
        lines.append("    </tr>")
    lines.append("</table>")
    return self._unicode("\n").join(lines)
##############################
# UNICODE WIDTH FUNCTIONS #
##############################
def _char_block_width(char):
    """Estimate how many display columns one character occupies.

    Arguments:

    char - the character's code point, as an integer

    Returns 2 for the wide East Asian ranges listed below, 0 for combining
    and control characters, -1 for backspace and delete, and 1 otherwise."""
    # Basic Latin, which is probably the most common case
    if 0x0021 <= char <= 0x007e:
        return 1
    # Chinese, Japanese, Korean (common)
    if 0x4e00 <= char <= 0x9fff:
        return 2
    # Hangul
    if 0xac00 <= char <= 0xd7af:
        return 2
    # Combining?
    if unicodedata.combining(uni_chr(char)):
        return 0
    # Hiragana and Katakana
    if 0x3040 <= char <= 0x309f or 0x30a0 <= char <= 0x30ff:
        return 2
    # Full-width Latin characters
    if 0xff01 <= char <= 0xff60:
        return 2
    # CJK punctuation
    if 0x3000 <= char <= 0x303e:
        return 2
    # Backspace and delete
    if char in (0x0008, 0x007f):
        return -1
    # Other control characters: the whole 0x00-0x1f range. A membership
    # test against (0x0000, 0x001f) only matched those two code points.
    if 0x0000 <= char <= 0x001f:
        return 0
    # Take a guess
    return 1
def _str_block_width(val):
    """Return the display width of val.

    Text matched by the module-level _re pattern is removed first; the
    widths of the remaining characters, as given by _char_block_width, are
    then added up."""
    stripped = _re.sub("", val)
    return sum(_char_block_width(ord(ch)) for ch in stripped)
##############################
# TABLE FACTORIES #
##############################
def from_csv(fp, field_names = None, **kwargs):
    """Build a PrettyTable from an open CSV file.

    The dialect is sniffed from the first 1024 characters. When no field
    names are supplied the first CSV row provides them. All values are
    stripped of surrounding whitespace.

    Arguments:

    fp - seekable file object containing CSV text
    field_names - optional list of field names
    kwargs - passed on to the PrettyTable constructor"""
    dialect = csv.Sniffer().sniff(fp.read(1024))
    fp.seek(0)
    reader = csv.reader(fp, dialect)
    table = PrettyTable(**kwargs)
    if not field_names:
        # The builtin next() advances the reader on Python 2.6+ and 3 alike.
        field_names = [x.strip() for x in next(reader)]
    table.field_names = field_names
    for row in reader:
        table.add_row([x.strip() for x in row])
    return table
def from_db_cursor(cursor, **kwargs):
    """Build a PrettyTable from the result set of a database cursor.

    Field names come from the first element of each entry in
    cursor.description and rows from cursor.fetchall(). Returns None when
    cursor.description is empty or None.

    Arguments:

    cursor - database cursor that has executed a query
    kwargs - passed on to the PrettyTable constructor"""
    if not cursor.description:
        return None
    table = PrettyTable(**kwargs)
    table.field_names = [col[0] for col in cursor.description]
    for row in cursor.fetchall():
        table.add_row(row)
    return table
class TableHandler(HTMLParser):
    """HTML parser that turns each table element it sees into a PrettyTable.

    Keyword arguments given to the constructor are passed to every
    PrettyTable that is created. Finished tables accumulate in self.tables."""
    def __init__(self, **kwargs):
        HTMLParser.__init__(self)
        self.kwargs = kwargs
        self.tables = []
        self.last_row = []
        self.rows = []
        self.max_row_width = 0
        self.active = None
        self.last_content = ""
        self.is_last_row_header = False
    def handle_starttag(self,tag, attrs):
        """Record the open tag; a th cell marks the current row as a header."""
        self.active = tag
        if tag == "th":
            self.is_last_row_header = True
    def handle_endtag(self,tag):
        """Close a cell, row or table and reset the text accumulator."""
        if tag in ["th", "td"]:
            stripped_content = self.last_content.strip()
            self.last_row.append(stripped_content)
        if tag == "tr":
            self.rows.append(
                (self.last_row, self.is_last_row_header))
            self.max_row_width = max(self.max_row_width, len(self.last_row))
            self.last_row = []
            self.is_last_row_header = False
        if tag == "table":
            table = self.generate_table(self.rows)
            self.tables.append(table)
            self.rows = []
            # Start the next table from scratch so that it is not padded
            # out to the width of this one.
            self.max_row_width = 0
        self.last_content = " "
        self.active = None
    def handle_data(self, data):
        """Accumulate text found between tags."""
        self.last_content += data
    def generate_table(self, rows):
        """
        Generates from a list of rows a PrettyTable object.

        rows is a list of (cells, is_header) pairs. Rows shorter than the
        widest row are padded in place with "-".
        """
        table = PrettyTable(**self.kwargs)
        for cells, is_header in rows:
            # Pad with exactly the number of missing cells; range(1, n)
            # added one too few, leaving the row too short for the table.
            missing = self.max_row_width - len(cells)
            if missing > 0:
                cells.extend(["-"] * missing)
            if is_header:
                self.make_fields_unique(cells)
                table.field_names = cells
            else:
                table.add_row(cells)
        return table
    def make_fields_unique(self, fields):
        """
        iterates over the row and make each field unique

        A field equal to an earlier one gets "'" appended, in place.
        """
        for i in range(0, len(fields)):
            for j in range(i+1, len(fields)):
                if fields[i] == fields[j]:
                    fields[j] += "'"
def from_html(html_code, **kwargs):
    """
    Parse a string of HTML code and return a list of PrettyTables, one for
    each table element found. Keyword arguments are handed to TableHandler.
    """
    handler = TableHandler(**kwargs)
    handler.feed(html_code)
    return handler.tables
def from_html_one(html_code, **kwargs):
    """
    Generates a PrettyTable from a string of HTML code which contains only a
    single table element.

    Raises Exception when the HTML contains no table or more than one.
    """
    tables = from_html(html_code, **kwargs)
    # Explicit checks rather than ``assert`` so they survive -O; an empty
    # result gets its own message instead of "More than one".
    if not tables:
        raise Exception("No <table> found in provided HTML code!")
    if len(tables) != 1:
        raise Exception("More than one <table> in provided HTML code! Use from_html instead.")
    return tables[0]
##############################
# MAIN (TEST FUNCTION) #
##############################
def main():
    """Print a small demonstration table sorted by population, largest first."""
    x = PrettyTable(["City name", "Area", "Population", "Annual Rainfall"])
    x.sortby = "Population"
    x.reversesort = True
    # The formatter appends the conversion letter itself ("%<fmt>d" and
    # "%<fmt>f"), so only width/precision belong here; "04d" and "6.1f"
    # left a stray "d" / "f" after every number.
    x.int_format["Area"] = "04"
    x.float_format = "6.1"
    x.align["City name"] = "l" # Left align city names
    x.add_row(["Adelaide", 1295, 1158259, 600.5])
    x.add_row(["Brisbane", 5905, 1857594, 1146.4])
    x.add_row(["Darwin", 112, 120900, 1714.7])
    x.add_row(["Hobart", 1357, 205556, 619.5])
    x.add_row(["Sydney", 2058, 4336374, 1214.8])
    x.add_row(["Melbourne", 1566, 3806092, 646.9])
    x.add_row(["Perth", 5386, 1554769, 869.4])
    print(x)
if __name__ == "__main__":
    main()