[
  {
    "path": ".gitignore",
    "content": "__pycache__/\ndata/\n*.pyc\n"
  },
  {
    "path": "decoder.py",
    "content": "import argparse\nimport math\nimport numpy as np\nfrom utils import *\nfrom scipy import fftpack\nfrom PIL import Image\n\n\nclass JPEGFileReader:\n    TABLE_SIZE_BITS = 16\n    BLOCKS_COUNT_BITS = 32\n\n    DC_CODE_LENGTH_BITS = 4\n    CATEGORY_BITS = 4\n\n    AC_CODE_LENGTH_BITS = 8\n    RUN_LENGTH_BITS = 4\n    SIZE_BITS = 4\n\n    def __init__(self, filepath):\n        self.__file = open(filepath, 'r')\n\n    def read_int(self, size):\n        if size == 0:\n            return 0\n\n        # the most significant bit indicates the sign of the number\n        bin_num = self.__read_str(size)\n        if bin_num[0] == '1':\n            return self.__int2(bin_num)\n        else:\n            return self.__int2(binstr_flip(bin_num)) * -1\n\n    def read_dc_table(self):\n        table = dict()\n\n        table_size = self.__read_uint(self.TABLE_SIZE_BITS)\n        for _ in range(table_size):\n            category = self.__read_uint(self.CATEGORY_BITS)\n            code_length = self.__read_uint(self.DC_CODE_LENGTH_BITS)\n            code = self.__read_str(code_length)\n            table[code] = category\n        return table\n\n    def read_ac_table(self):\n        table = dict()\n\n        table_size = self.__read_uint(self.TABLE_SIZE_BITS)\n        for _ in range(table_size):\n            run_length = self.__read_uint(self.RUN_LENGTH_BITS)\n            size = self.__read_uint(self.SIZE_BITS)\n            code_length = self.__read_uint(self.AC_CODE_LENGTH_BITS)\n            code = self.__read_str(code_length)\n            table[code] = (run_length, size)\n        return table\n\n    def read_blocks_count(self):\n        return self.__read_uint(self.BLOCKS_COUNT_BITS)\n\n    def read_huffman_code(self, table):\n        prefix = ''\n        # TODO: break the loop if __read_char is not returing new char\n        while prefix not in table:\n            prefix += self.__read_char()\n        return table[prefix]\n\n    def __read_uint(self, size):\n        
if size <= 0:\n            raise ValueError(\"size of unsigned int should be greater than 0\")\n        return self.__int2(self.__read_str(size))\n\n    def __read_str(self, length):\n        return self.__file.read(length)\n\n    def __read_char(self):\n        return self.__read_str(1)\n\n    def __int2(self, bin_num):\n        return int(bin_num, 2)\n\n\ndef read_image_file(filepath):\n    reader = JPEGFileReader(filepath)\n\n    tables = dict()\n    for table_name in ['dc_y', 'ac_y', 'dc_c', 'ac_c']:\n        if 'dc' in table_name:\n            tables[table_name] = reader.read_dc_table()\n        else:\n            tables[table_name] = reader.read_ac_table()\n\n    blocks_count = reader.read_blocks_count()\n\n    dc = np.empty((blocks_count, 3), dtype=np.int32)\n    ac = np.empty((blocks_count, 63, 3), dtype=np.int32)\n\n    for block_index in range(blocks_count):\n        for component in range(3):\n            dc_table = tables['dc_y'] if component == 0 else tables['dc_c']\n            ac_table = tables['ac_y'] if component == 0 else tables['ac_c']\n\n            category = reader.read_huffman_code(dc_table)\n            dc[block_index, component] = reader.read_int(category)\n\n            cells_count = 0\n\n            # TODO: try to make reading AC coefficients better\n            while cells_count < 63:\n                run_length, size = reader.read_huffman_code(ac_table)\n\n                if (run_length, size) == (0, 0):\n                    while cells_count < 63:\n                        ac[block_index, cells_count, component] = 0\n                        cells_count += 1\n                else:\n                    for i in range(run_length):\n                        ac[block_index, cells_count, component] = 0\n                        cells_count += 1\n                    if size == 0:\n                        ac[block_index, cells_count, component] = 0\n                    else:\n                        value = reader.read_int(size)\n                 
       ac[block_index, cells_count, component] = value\n                    cells_count += 1\n\n    return dc, ac, tables, blocks_count\n\n\ndef zigzag_to_block(zigzag):\n    # assuming that the width and the height of the block are equal\n    rows = cols = int(math.sqrt(len(zigzag)))\n\n    if rows * cols != len(zigzag):\n        raise ValueError(\"length of zigzag should be a perfect square\")\n\n    block = np.empty((rows, cols), np.int32)\n\n    for i, point in enumerate(zigzag_points(rows, cols)):\n        block[point] = zigzag[i]\n\n    return block\n\n\ndef dequantize(block, component):\n    q = load_quantization_table(component)\n    return block * q\n\n\ndef idct_2d(image):\n    return fftpack.idct(fftpack.idct(image.T, norm='ortho').T, norm='ortho')\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"input\", help=\"path to the input image\")\n    args = parser.parse_args()\n\n    dc, ac, tables, blocks_count = read_image_file(args.input)\n\n    # assuming that the block is a 8x8 square\n    block_side = 8\n\n    # assuming that the image height and width are equal\n    image_side = int(math.sqrt(blocks_count)) * block_side\n\n    blocks_per_line = image_side // block_side\n\n    npmat = np.empty((image_side, image_side, 3), dtype=np.uint8)\n\n    for block_index in range(blocks_count):\n        i = block_index // blocks_per_line * block_side\n        j = block_index % blocks_per_line * block_side\n\n        for c in range(3):\n            zigzag = [dc[block_index, c]] + list(ac[block_index, :, c])\n            quant_matrix = zigzag_to_block(zigzag)\n            dct_matrix = dequantize(quant_matrix, 'lum' if c == 0 else 'chrom')\n            block = idct_2d(dct_matrix)\n            npmat[i:i+8, j:j+8, c] = block + 128\n\n    image = Image.fromarray(npmat, 'YCbCr')\n    image = image.convert('RGB')\n    image.show()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "encoder.py",
    "content": "import argparse\nimport os\nimport math\nimport numpy as np\nfrom utils import *\nfrom scipy import fftpack\nfrom PIL import Image\nfrom huffman import HuffmanTree\n\n\ndef quantize(block, component):\n    q = load_quantization_table(component)\n    return (block / q).round().astype(np.int32)\n\n\ndef block_to_zigzag(block):\n    return np.array([block[point] for point in zigzag_points(*block.shape)])\n\n\ndef dct_2d(image):\n    return fftpack.dct(fftpack.dct(image.T, norm='ortho').T, norm='ortho')\n\n\ndef run_length_encode(arr):\n    # determine where the sequence is ending prematurely\n    last_nonzero = -1\n    for i, elem in enumerate(arr):\n        if elem != 0:\n            last_nonzero = i\n\n    # each symbol is a (RUNLENGTH, SIZE) tuple\n    symbols = []\n\n    # values are binary representations of array elements using SIZE bits\n    values = []\n\n    run_length = 0\n\n    for i, elem in enumerate(arr):\n        if i > last_nonzero:\n            symbols.append((0, 0))\n            values.append(int_to_binstr(0))\n            break\n        elif elem == 0 and run_length < 15:\n            run_length += 1\n        else:\n            size = bits_required(elem)\n            symbols.append((run_length, size))\n            values.append(int_to_binstr(elem))\n            run_length = 0\n    return symbols, values\n\n\ndef write_to_file(filepath, dc, ac, blocks_count, tables):\n    try:\n        f = open(filepath, 'w')\n    except FileNotFoundError as e:\n        raise FileNotFoundError(\n                \"No such directory: {}\".format(\n                    os.path.dirname(filepath))) from e\n\n    for table_name in ['dc_y', 'ac_y', 'dc_c', 'ac_c']:\n\n        # 16 bits for 'table_size'\n        f.write(uint_to_binstr(len(tables[table_name]), 16))\n\n        for key, value in tables[table_name].items():\n            if table_name in {'dc_y', 'dc_c'}:\n                # 4 bits for the 'category'\n                # 4 bits for 'code_length'\n     
           # 'code_length' bits for 'huffman_code'\n                f.write(uint_to_binstr(key, 4))\n                f.write(uint_to_binstr(len(value), 4))\n                f.write(value)\n            else:\n                # 4 bits for 'run_length'\n                # 4 bits for 'size'\n                # 8 bits for 'code_length'\n                # 'code_length' bits for 'huffman_code'\n                f.write(uint_to_binstr(key[0], 4))\n                f.write(uint_to_binstr(key[1], 4))\n                f.write(uint_to_binstr(len(value), 8))\n                f.write(value)\n\n    # 32 bits for 'blocks_count'\n    f.write(uint_to_binstr(blocks_count, 32))\n\n    for b in range(blocks_count):\n        for c in range(3):\n            category = bits_required(dc[b, c])\n            symbols, values = run_length_encode(ac[b, :, c])\n\n            dc_table = tables['dc_y'] if c == 0 else tables['dc_c']\n            ac_table = tables['ac_y'] if c == 0 else tables['ac_c']\n\n            f.write(dc_table[category])\n            f.write(int_to_binstr(dc[b, c]))\n\n            for i in range(len(symbols)):\n                f.write(ac_table[tuple(symbols[i])])\n                f.write(values[i])\n    f.close()\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"input\", help=\"path to the input image\")\n    parser.add_argument(\"output\", help=\"path to the output image\")\n    args = parser.parse_args()\n\n    input_file = args.input\n    output_file = args.output\n\n    image = Image.open(input_file)\n    ycbcr = image.convert('YCbCr')\n\n    npmat = np.array(ycbcr, dtype=np.uint8)\n\n    rows, cols = npmat.shape[0], npmat.shape[1]\n\n    # block size: 8x8\n    if rows % 8 == cols % 8 == 0:\n        blocks_count = rows // 8 * cols // 8\n    else:\n        raise ValueError((\"the width and height of the image \"\n                          \"should both be mutiples of 8\"))\n\n    # dc is the top-left cell of the block, ac are all the other 
cells\n    dc = np.empty((blocks_count, 3), dtype=np.int32)\n    ac = np.empty((blocks_count, 63, 3), dtype=np.int32)\n\n    for i in range(0, rows, 8):\n        for j in range(0, cols, 8):\n            try:\n                block_index += 1\n            except NameError:\n                block_index = 0\n\n            for k in range(3):\n                # split 8x8 block and center the data range on zero\n                # [0, 255] --> [-128, 127]\n                block = npmat[i:i+8, j:j+8, k] - 128\n\n                dct_matrix = dct_2d(block)\n                quant_matrix = quantize(dct_matrix,\n                                        'lum' if k == 0 else 'chrom')\n                zz = block_to_zigzag(quant_matrix)\n\n                dc[block_index, k] = zz[0]\n                ac[block_index, :, k] = zz[1:]\n\n    H_DC_Y = HuffmanTree(np.vectorize(bits_required)(dc[:, 0]))\n    H_DC_C = HuffmanTree(np.vectorize(bits_required)(dc[:, 1:].flat))\n    H_AC_Y = HuffmanTree(\n            flatten(run_length_encode(ac[i, :, 0])[0]\n                    for i in range(blocks_count)))\n    H_AC_C = HuffmanTree(\n            flatten(run_length_encode(ac[i, :, j])[0]\n                    for i in range(blocks_count) for j in [1, 2]))\n\n    tables = {'dc_y': H_DC_Y.value_to_bitstring_table(),\n              'ac_y': H_AC_Y.value_to_bitstring_table(),\n              'dc_c': H_DC_C.value_to_bitstring_table(),\n              'ac_c': H_AC_C.value_to_bitstring_table()}\n\n    write_to_file(output_file, dc, ac, blocks_count, tables)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "huffman.py",
    "content": "from queue import PriorityQueue\n\n\nclass HuffmanTree:\n\n    class __Node:\n        def __init__(self, value, freq, left_child, right_child):\n            self.value = value\n            self.freq = freq\n            self.left_child = left_child\n            self.right_child = right_child\n\n        @classmethod\n        def init_leaf(self, value, freq):\n            return self(value, freq, None, None)\n\n        @classmethod\n        def init_node(self, left_child, right_child):\n            freq = left_child.freq + right_child.freq\n            return self(None, freq, left_child, right_child)\n\n        def is_leaf(self):\n            return self.value is not None\n\n        def __eq__(self, other):\n            stup = self.value, self.freq, self.left_child, self.right_child\n            otup = other.value, other.freq, other.left_child, other.right_child\n            return stup == otup\n\n        def __nq__(self, other):\n            return not (self == other)\n\n        def __lt__(self, other):\n            return self.freq < other.freq\n\n        def __le__(self, other):\n            return self.freq < other.freq or self.freq == other.freq\n\n        def __gt__(self, other):\n            return not (self <= other)\n\n        def __ge__(self, other):\n            return not (self < other)\n\n    def __init__(self, arr):\n        q = PriorityQueue()\n\n        # calculate frequencies and insert them into a priority queue\n        for val, freq in self.__calc_freq(arr).items():\n            q.put(self.__Node.init_leaf(val, freq))\n\n        while q.qsize() >= 2:\n            u = q.get()\n            v = q.get()\n\n            q.put(self.__Node.init_node(u, v))\n\n        self.__root = q.get()\n\n        # dictionaries to store huffman table\n        self.__value_to_bitstring = dict()\n\n    def value_to_bitstring_table(self):\n        if len(self.__value_to_bitstring.keys()) == 0:\n            self.__create_huffman_table()\n        return 
self.__value_to_bitstring\n\n    def __create_huffman_table(self):\n        def tree_traverse(current_node, bitstring=''):\n            if current_node is None:\n                return\n            if current_node.is_leaf():\n                self.__value_to_bitstring[current_node.value] = bitstring\n                return\n            tree_traverse(current_node.left_child, bitstring + '0')\n            tree_traverse(current_node.right_child, bitstring + '1')\n\n        tree_traverse(self.__root)\n\n    def __calc_freq(self, arr):\n        freq_dict = dict()\n        for elem in arr:\n            if elem in freq_dict:\n                freq_dict[elem] += 1\n            else:\n                freq_dict[elem] = 1\n        return freq_dict\n"
  },
  {
    "path": "utils.py",
    "content": "import numpy as np\n\n\ndef load_quantization_table(component):\n    # Quantization Table for: Photoshop - (Save For Web 080)\n    # (http://www.impulseadventure.com/photo/jpeg-quantization.html)\n    if component == 'lum':\n        q = np.array([[2, 2, 2, 2, 3, 4, 5, 6],\n                      [2, 2, 2, 2, 3, 4, 5, 6],\n                      [2, 2, 2, 2, 4, 5, 7, 9],\n                      [2, 2, 2, 4, 5, 7, 9, 12],\n                      [3, 3, 4, 5, 8, 10, 12, 12],\n                      [4, 4, 5, 7, 10, 12, 12, 12],\n                      [5, 5, 7, 9, 12, 12, 12, 12],\n                      [6, 6, 9, 12, 12, 12, 12, 12]])\n    elif component == 'chrom':\n        q = np.array([[3, 3, 5, 9, 13, 15, 15, 15],\n                      [3, 4, 6, 11, 14, 12, 12, 12],\n                      [5, 6, 9, 14, 12, 12, 12, 12],\n                      [9, 11, 14, 12, 12, 12, 12, 12],\n                      [13, 14, 12, 12, 12, 12, 12, 12],\n                      [15, 12, 12, 12, 12, 12, 12, 12],\n                      [15, 12, 12, 12, 12, 12, 12, 12],\n                      [15, 12, 12, 12, 12, 12, 12, 12]])\n    else:\n        raise ValueError((\n            \"component should be either 'lum' or 'chrom', \"\n            \"but '{comp}' was found\").format(comp=component))\n\n    return q\n\n\ndef zigzag_points(rows, cols):\n    # constants for directions\n    UP, DOWN, RIGHT, LEFT, UP_RIGHT, DOWN_LEFT = range(6)\n\n    # move the point in different directions\n    def move(direction, point):\n        return {\n            UP: lambda point: (point[0] - 1, point[1]),\n            DOWN: lambda point: (point[0] + 1, point[1]),\n            LEFT: lambda point: (point[0], point[1] - 1),\n            RIGHT: lambda point: (point[0], point[1] + 1),\n            UP_RIGHT: lambda point: move(UP, move(RIGHT, point)),\n            DOWN_LEFT: lambda point: move(DOWN, move(LEFT, point))\n        }[direction](point)\n\n    # return true if point is inside the block bounds\n    
def inbounds(point):\n        return 0 <= point[0] < rows and 0 <= point[1] < cols\n\n    # start in the top-left cell\n    point = (0, 0)\n\n    # True when moving up-right, False when moving down-left\n    move_up = True\n\n    for i in range(rows * cols):\n        yield point\n        if move_up:\n            if inbounds(move(UP_RIGHT, point)):\n                point = move(UP_RIGHT, point)\n            else:\n                move_up = False\n                if inbounds(move(RIGHT, point)):\n                    point = move(RIGHT, point)\n                else:\n                    point = move(DOWN, point)\n        else:\n            if inbounds(move(DOWN_LEFT, point)):\n                point = move(DOWN_LEFT, point)\n            else:\n                move_up = True\n                if inbounds(move(DOWN, point)):\n                    point = move(DOWN, point)\n                else:\n                    point = move(RIGHT, point)\n\n\ndef bits_required(n):\n    n = abs(n)\n    result = 0\n    while n > 0:\n        n >>= 1\n        result += 1\n    return result\n\n\ndef binstr_flip(binstr):\n    # check if binstr is a binary string\n    if not set(binstr).issubset('01'):\n        raise ValueError(\"binstr should have only '0's and '1's\")\n    return ''.join(map(lambda c: '0' if c == '1' else '1', binstr))\n\n\ndef uint_to_binstr(number, size):\n    return bin(number)[2:][-size:].zfill(size)\n\n\ndef int_to_binstr(n):\n    if n == 0:\n        return ''\n\n    binstr = bin(abs(n))[2:]\n\n    # change every 0 to 1 and vice versa when n is negative\n    return binstr if n > 0 else binstr_flip(binstr)\n\n\ndef flatten(lst):\n    return [item for sublist in lst for item in sublist]\n"
  }
]