[
  {
    "path": ".dockerignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n\n.vscode/\nmodels/\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n\n.vscode/\nconfigs/\nmodels/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n    - repo: https://github.com/pre-commit/pre-commit-hooks\n      rev: v2.3.0\n      hooks:\n          - id: check-yaml\n          - id: end-of-file-fixer\n          - id: trailing-whitespace\n    - repo: https://github.com/psf/black\n      rev: 22.10.0\n      hooks:\n          - id: black\n    - repo: https://github.com/pycqa/isort\n      rev: 5.12.0\n      hooks:\n          - id: isort\n"
  },
  {
    "path": "Changelog",
    "content": "# Changelog\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),\nand this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\n\n## [Unreleased]\n### Added\n- Use RaggedTensor in the train and eval pipelines\n\n## [0.2.0] - 2021-07-29\n### Changed\n- New export script\n- New way to build the model\n\n## [0.1.1] - 2021-01-22\n### Changed\n- Reduce minimum TensorFlow version to 2.2\n- Better way to add post-processing\n- Use StaticHashTable instead of StringLookup layer\n\n## [0.1.0] - 2021-01-09\n### Added\n- Add <UNK> label\n- Add docker support\n- Add EditDistance metric\n- Add Decoders for a truly end-to-end model [experimental]\n\n### Changed\n- Update minimum TensorFlow version to 2.3.0\n- Change img_height, img_width, img_channels to img_shape\n- Build a new data pipeline\n- Use new preprocessing\n- Change model_dir to save_dir in train.py\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM tensorflow/tensorflow:latest-gpu\n\nWORKDIR /workspace\n\nCOPY . /workspace\n\nRUN pip install --no-cache-dir -r requirements.txt\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019-2021 Huang Yiming\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Convolutional Recurrent Neural Network for End-to-End Text Recognition - TensorFlow 2\n\n![TensorFlow version](https://img.shields.io/badge/TensorFlow->=2.3-FF6F00?logo=tensorflow)\n![Python version](https://img.shields.io/badge/Python->=3.6-3776AB?logo=python)\n[![Paper](https://img.shields.io/badge/paper-arXiv:1507.05717-B3181B?logo=arXiv)](https://arxiv.org/abs/1507.05717)\n[![Zhihu](https://img.shields.io/badge/知乎-文本识别网络CRNN—实现简述-blue?logo=zhihu)](https://zhuanlan.zhihu.com/p/122512498)\n\nThis is a re-implementation of the CRNN network, built with TensorFlow 2. This repository may help you understand how to build an end-to-end text recognition network easily. The official [repo](https://github.com/bgshih/crnn) was implemented by [bgshih](https://github.com/bgshih).\n\n## Abstract\n\nThis repo aims to build a simple, efficient text recognition network using the various components of TensorFlow 2. The model is built with the Keras API, the data pipeline with `tf.data`, and training runs through `model.fit`, so most TensorFlow 2 features are available, such as `TensorBoard`, `Distribution Strategy` and `TensorFlow Profiler`.\n\n## Installation\n\n```bash\n$ pip install -r requirements.txt\n```\n\n## Demo\n\nHere is an example model trained on the MJSynth dataset. It can only predict 0-9 and a-z (case-insensitive).\n\n```bash\n$ wget https://github.com/FLming/CRNN.tf2/releases/download/v0.2.0/SavedModel.tgz\n$ tar xzvf SavedModel.tgz\n$ python tools/demo.py --images example/images/ --config configs/mjsynth.yml --model SavedModel\n```\n\nYou should see output like this:\n```\nPath: example/images/word_1.png, y_pred: [b'tiredness'], probability: [0.9998626]\nPath: example/images/word_3.png, y_pred: [b'a'], probability: [0.67493004]\nPath: example/images/2_Reimbursing_64165.jpg, y_pred: [b'reimbursing'], probability: [0.990946]\nPath: example/images/word_2.png, y_pred: [b'kills'], probability: [0.9994573]\nPath: example/images/1_Paintbrushes_55044.jpg, y_pred: [b'paintbrushes'], probability: [0.9984008]\nPath: example/images/3_Creationisms_17934.jpg, y_pred: [b'creationisms'], probability: [0.99792457]\n```\n\nAs for decoding methods, beam search is sometimes more accurate than greedy decoding, but it is more expensive.\n\n## Train\n\nBefore training, [prepare](#data-prepare) your data. All predictable characters are defined in the [table.txt](example/table.txt) file. The training process is configured by the [yml](configs/mjsynth.yml) file.\n\nThe training script uses all GPUs by default (via `tf.distribute.MirroredStrategy`). To use specific GPUs, set the `CUDA_VISIBLE_DEVICES` environment variable.\n\n```bash\n$ python crnn/train.py --config configs/mjsynth.yml --save_dir PATH/TO/SAVE\n```\n\nThe training process can be visualized in TensorBoard. The logs are written to the `logs` subdirectory of `--save_dir`.\n\n```bash\n$ tensorboard --logdir PATH/TO/SAVE/logs\n```\n\nFor more options, please refer to the [config](configs/mjsynth.yml) file.\n\n## Data prepare\n\nTo train this network, you need a lookup table, images and their corresponding labels. The example data is copied from the [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [ICDAR2013](https://rrc.cvc.uab.es/?ch=2&com=introduction) datasets.\n\n### [Lookup table](./example/table.txt)\n\nThe file lists all characters, one per line, plus the blank label. Characters that are not in the table are mapped to line 0, which is why the example table starts with `<UNK>`. In principle the blank label could go anywhere, but the TensorFlow CTC decoders used here do not let you change its index, so put it on the last line. Any token can serve as the blank. The example uses `<BLK>`.\n\n### Image data\n\nBecause this is an end-to-end method, the positions of the characters in the image do not need to be annotated.\n\n![Paintbrushes](example/images/1_Paintbrushes_55044.jpg)\n![Creationisms](example/images/3_Creationisms_17934.jpg)\n![Reimbursing](example/images/2_Reimbursing_64165.jpg)\n\nThe labels corresponding to these three images are `Paintbrushes`, `Creationisms` and `Reimbursing`.\n\n### Annotation file\n\nWrite each image path and its label to a text file in one of the supported formats, following the example data. The input pipeline detects the format automatically from the first line of each file. Adding a custom format is also simple. See the [dataset factory](crnn/dataset_factory.py).\n\n#### Supported formats\n\n- [MJSynth](./example/mjsynth_annotation.txt)\n- [ICDAR2013/2015](./example/icdar2013_annotation.txt)\n- [Simple](./example/simple_annotation.txt), e.g. `example.jpg label`\n\n## Eval\n\n```bash\n$ python crnn/eval.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT\n```\n\n## Convert & Ecosystem\n\nThe trained model can be used with other parts of the TensorFlow ecosystem, for example deployed with `TensorFlow Serving`. Before deploying, pick a good weight file and convert the model to the `SavedModel` format with the following command. It appends the post-processing layer to the end of the model and drops the optimizer:\n\n```bash\n$ python crnn/export.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT --pre rescale --post greedy --output PATH/TO/OUTPUT\n```\n\nThe model can also be converted with `TensorFlow Lite`, so it can be deployed to Android, iOS, etc.\n\nNote: the decoders cannot be converted to `TensorFlow Lite` because they depend on assets (the lookup table file). Use `--post softmax` or no post-processing instead.\n"
  },
  {
    "path": "crnn/dataset_factory.py",
    "content": "import os\nimport re\n\nimport tensorflow as tf\n\ntry:\n    AUTOTUNE = tf.data.AUTOTUNE\nexcept AttributeError:\n    # tf < 2.4.0\n    AUTOTUNE = tf.data.experimental.AUTOTUNE\n\n\nclass Dataset(tf.data.TextLineDataset):\n    def __init__(self, filename, **kwargs):\n        self.dirname = os.path.dirname(filename)\n        super().__init__(filename, **kwargs)\n\n    def parse_func(self, line):\n        raise NotImplementedError\n\n    def parse_line(self, line):\n        line = tf.strings.strip(line)\n        img_relative_path, label = self.parse_func(line)\n        img_path = tf.strings.join([self.dirname, os.sep, img_relative_path])\n        return img_path, label\n\n\nclass SimpleDataset(Dataset):\n    def parse_func(self, line):\n        splited_line = tf.strings.split(line)\n        img_relative_path, label = splited_line[0], splited_line[1]\n        return img_relative_path, label\n\n\nclass MJSynthDataset(Dataset):\n    def parse_func(self, line):\n        splited_line = tf.strings.split(line)\n        img_relative_path = splited_line[0]\n        label = tf.strings.split(img_relative_path, sep=\"_\")[1]\n        return img_relative_path, label\n\n\nclass ICDARDataset(Dataset):\n    def parse_func(self, line):\n        splited_line = tf.strings.split(line, sep=\",\")\n        img_relative_path, label = splited_line[0], splited_line[1]\n        label = tf.strings.strip(label)\n        label = tf.strings.regex_replace(label, r'\"', \"\")\n        return img_relative_path, label\n\n\nclass DatasetBuilder:\n    def __init__(\n        self,\n        table_path,\n        img_shape=(32, None, 3),\n        max_img_width=300,\n        ignore_case=False,\n    ):\n        # map unknown label to 0\n        self.table = tf.lookup.StaticHashTable(\n            tf.lookup.TextFileInitializer(\n                table_path,\n                tf.string,\n                tf.lookup.TextFileIndex.WHOLE_LINE,\n                tf.int64,\n                
tf.lookup.TextFileIndex.LINE_NUMBER,\n            ),\n            0,\n        )\n        self.img_shape = img_shape\n        self.ignore_case = ignore_case\n        if img_shape[1] is None:\n            self.max_img_width = max_img_width\n            self.preserve_aspect_ratio = True\n        else:\n            self.preserve_aspect_ratio = False\n\n    @property\n    def num_classes(self):\n        return self.table.size()\n\n    def _parse_annotation(self, path):\n        with open(path) as f:\n            line = f.readline().strip()\n        if re.fullmatch(r\".*/*\\d+_.+_(\\d+)\\.\\w+ \\1\", line):\n            return MJSynthDataset(path)\n        elif re.fullmatch(r'.*/*word_\\d+\\.\\w+, \".+\"', line):\n            return ICDARDataset(path)\n        elif re.fullmatch(r\".+\\.\\w+ .+\", line):\n            return SimpleDataset(path)\n        else:\n            raise ValueError(\"Unsupported annotation format\")\n\n    def _concatenate_ds(self, ann_paths):\n        datasets = [self._parse_annotation(path) for path in ann_paths]\n        concatenated_ds = datasets[0].map(datasets[0].parse_line)\n        for ds in datasets[1:]:\n            ds = ds.map(ds.parse_line)\n            concatenated_ds = concatenated_ds.concatenate(ds)\n        return concatenated_ds\n\n    def _decode_img(self, filename, label):\n        img = tf.io.read_file(filename)\n        img = tf.io.decode_jpeg(img, channels=self.img_shape[-1])\n        if self.preserve_aspect_ratio:\n            img_shape = tf.shape(img)\n            scale_factor = self.img_shape[0] / img_shape[0]\n            img_width = scale_factor * tf.cast(img_shape[1], tf.float64)\n            img_width = tf.cast(img_width, tf.int32)\n        else:\n            img_width = self.img_shape[1]\n        img = tf.image.resize(img, (self.img_shape[0], img_width)) / 255.0\n        return img, label\n\n    def _filter_img(self, img, label):\n        img_shape = tf.shape(img)\n        return img_shape[1] < self.max_img_width\n\n    def _tokenize(self, imgs, labels):\n        chars = tf.strings.unicode_split(labels, \"UTF-8\")\n        tokens = tf.ragged.map_flat_values(self.table.lookup, chars)\n        # TODO(hym) Waiting for official support to use RaggedTensor in keras\n        tokens = tokens.to_sparse()\n        return imgs, tokens\n\n    def __call__(self, ann_paths, batch_size, is_training):\n        ds = self._concatenate_ds(ann_paths)\n        if self.ignore_case:\n            ds = ds.map(lambda x, y: (x, tf.strings.lower(y)))\n        if is_training:\n            ds = ds.shuffle(buffer_size=10000)\n        ds = ds.map(self._decode_img, AUTOTUNE)\n        if self.preserve_aspect_ratio and batch_size != 1:\n            ds = ds.filter(self._filter_img)\n            ds = ds.padded_batch(batch_size, drop_remainder=is_training)\n        else:\n            ds = ds.batch(batch_size, drop_remainder=is_training)\n        ds = ds.map(self._tokenize, AUTOTUNE)\n        ds = ds.prefetch(AUTOTUNE)\n        return ds\n"
  },
  {
    "path": "crnn/decoders.py",
    "content": "import tensorflow as tf\nfrom tensorflow import keras\n\n\nclass CTCDecoder(keras.layers.Layer):\n    def __init__(self, table_path, **kwargs):\n        super().__init__(**kwargs)\n        self.table = tf.lookup.StaticHashTable(\n            tf.lookup.TextFileInitializer(\n                table_path,\n                tf.int64,\n                tf.lookup.TextFileIndex.LINE_NUMBER,\n                tf.string,\n                tf.lookup.TextFileIndex.WHOLE_LINE,\n            ),\n            \"\",\n        )\n\n    def detokenize(self, x):\n        x = tf.RaggedTensor.from_sparse(x)\n        x = tf.ragged.map_flat_values(self.table.lookup, x)\n        strings = tf.strings.reduce_join(x, axis=1)\n        return strings\n\n\nclass CTCGreedyDecoder(CTCDecoder):\n    def __init__(self, table_path, merge_repeated=True, **kwargs):\n        super().__init__(table_path, **kwargs)\n        self.merge_repeated = merge_repeated\n\n    def call(self, inputs):\n        input_shape = tf.shape(inputs)\n        sequence_length = tf.fill([input_shape[0]], input_shape[1])\n        decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(\n            tf.transpose(inputs, perm=[1, 0, 2]),\n            sequence_length,\n            self.merge_repeated,\n        )\n        strings = self.detokenize(decoded[0])\n        labels = tf.cast(decoded[0], tf.int32)\n        loss = tf.nn.ctc_loss(\n            labels=labels,\n            logits=inputs,\n            label_length=None,\n            logit_length=sequence_length,\n            logits_time_major=False,\n            blank_index=-1,\n        )\n        probability = tf.math.exp(-loss)\n        return strings, probability\n\n\nclass CTCBeamSearchDecoder(CTCDecoder):\n    def __init__(self, table_path, beam_width=100, top_paths=1, **kwargs):\n        super().__init__(table_path, **kwargs)\n        self.beam_width = beam_width\n        self.top_paths = top_paths\n\n    def call(self, inputs):\n        input_shape = 
tf.shape(inputs)\n        decoded, log_probability = tf.nn.ctc_beam_search_decoder(\n            tf.transpose(inputs, perm=[1, 0, 2]),\n            tf.fill([input_shape[0]], input_shape[1]),\n            self.beam_width,\n            self.top_paths,\n        )\n        strings = []\n        for i in range(self.top_paths):\n            strings.append(self.detokenize(decoded[i]))\n        # Each path yields a [batch_size] tensor; stack them into\n        # [batch_size, top_paths] to match the shape of log_probability.\n        strings = tf.stack(strings, axis=1)\n        probability = tf.math.exp(log_probability)\n        return strings, probability\n"
  },
  {
    "path": "crnn/eval.py",
    "content": "import argparse\nimport pprint\n\nimport yaml\n\nfrom dataset_factory import DatasetBuilder\nfrom losses import CTCLoss\nfrom metrics import EditDistance\nfrom metrics import SequenceAccuracy\nfrom models import build_model\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--config\", type=str, required=True, help=\"The config file path.\"\n)\nparser.add_argument(\n    \"--weight\", type=str, required=True, help=\"The saved weight path.\"\n)\nargs = parser.parse_args()\n\nwith open(args.config) as f:\n    config = yaml.load(f, Loader=yaml.Loader)[\"eval\"]\npprint.pprint(config)\n\ndataset_builder = DatasetBuilder(**config[\"dataset_builder\"])\nds = dataset_builder(config[\"ann_paths\"], config[\"batch_size\"], False)\nmodel = build_model(\n    dataset_builder.num_classes,\n    weight=args.weight,\n    img_shape=config[\"dataset_builder\"][\"img_shape\"],\n)\nmodel.compile(loss=CTCLoss(), metrics=[SequenceAccuracy(), EditDistance()])\nmodel.evaluate(ds)\n"
  },
  {
    "path": "crnn/export.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport yaml\nfrom tensorflow import keras\n\nfrom decoders import CTCBeamSearchDecoder\nfrom decoders import CTCGreedyDecoder\nfrom models import build_model\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--config\", type=Path, required=True, help=\"The config file path.\"\n)\nparser.add_argument(\n    \"--weight\", type=str, required=True, help=\"The saved weight path.\"\n)\nparser.add_argument(\n    \"--pre\", type=str, help=\"Pre-processing to prepend: rescale.\"\n)\nparser.add_argument(\n    \"--post\",\n    type=str,\n    help=\"Post-processing to append: softmax, greedy or beam_search.\",\n)\nparser.add_argument(\n    \"--output\", type=str, required=True, help=\"The output path.\"\n)\nargs = parser.parse_args()\n\nwith args.config.open() as f:\n    config = yaml.load(f, Loader=yaml.Loader)[\"dataset_builder\"]\n\nwith open(config[\"table_path\"]) as f:\n    num_classes = len(f.readlines())\n\nif args.pre == \"rescale\":\n    preprocess = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)\nelse:\n    preprocess = None\n\nif args.post == \"softmax\":\n    postprocess = keras.layers.Softmax()\nelif args.post == \"greedy\":\n    postprocess = CTCGreedyDecoder(config[\"table_path\"])\nelif args.post == \"beam_search\":\n    postprocess = CTCBeamSearchDecoder(config[\"table_path\"])\nelse:\n    postprocess = None\n\nmodel = build_model(\n    num_classes,\n    weight=args.weight,\n    preprocess=preprocess,\n    postprocess=postprocess,\n    img_shape=config[\"img_shape\"],\n)\nmodel.summary()\nmodel.save(args.output)\n"
  },
  {
    "path": "crnn/losses.py",
    "content": "import tensorflow as tf\nfrom tensorflow import keras\n\n\nclass CTCLoss(keras.losses.Loss):\n    \"\"\"A Keras loss that wraps tf.nn.ctc_loss.\n\n    Attributes:\n        logits_time_major: If False (default), logits are shaped\n            [batch, time, logits]. If True, logits are shaped\n            [time, batch, logits].\n        blank_index: The class index used for the blank label. Defaults to\n            -1, i.e. num_classes - 1.\n    \"\"\"\n\n    def __init__(\n        self, logits_time_major=False, blank_index=-1, name=\"ctc_loss\"\n    ):\n        super().__init__(name=name)\n        self.logits_time_major = logits_time_major\n        self.blank_index = blank_index\n\n    def call(self, y_true, y_pred):\n        \"\"\"Computes the CTC (Connectionist Temporal Classification) loss.\n\n        This runs on the CPU because y_true is a SparseTensor.\n        \"\"\"\n        y_true = tf.cast(y_true, tf.int32)\n        y_pred_shape = tf.shape(y_pred)\n        logit_length = tf.fill([y_pred_shape[0]], y_pred_shape[1])\n        loss = tf.nn.ctc_loss(\n            labels=y_true,\n            logits=y_pred,\n            label_length=None,\n            logit_length=logit_length,\n            logits_time_major=self.logits_time_major,\n            blank_index=self.blank_index,\n        )\n        return tf.math.reduce_mean(loss)\n"
  },
  {
    "path": "crnn/metrics.py",
    "content": "import tensorflow as tf\nfrom tensorflow import keras\n\n\nclass SequenceAccuracy(keras.metrics.Metric):\n    def __init__(self, name=\"sequence_accuracy\", **kwargs):\n        super().__init__(name=name, **kwargs)\n        self.total = self.add_weight(name=\"total\", initializer=\"zeros\")\n        self.count = self.add_weight(name=\"count\", initializer=\"zeros\")\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        def sparse2dense(tensor, shape):\n            tensor = tf.sparse.reset_shape(tensor, shape)\n            tensor = tf.sparse.to_dense(tensor, default_value=-1)\n            tensor = tf.cast(tensor, tf.float32)\n            return tensor\n\n        y_true_shape = tf.shape(y_true)\n        batch_size = y_true_shape[0]\n        y_pred_shape = tf.shape(y_pred)\n        max_width = tf.math.maximum(y_true_shape[1], y_pred_shape[1])\n        logit_length = tf.fill([batch_size], y_pred_shape[1])\n        decoded, _ = tf.nn.ctc_greedy_decoder(\n            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),\n            sequence_length=logit_length,\n        )\n        y_true = sparse2dense(y_true, [batch_size, max_width])\n        y_pred = sparse2dense(decoded[0], [batch_size, max_width])\n        num_errors = tf.math.reduce_any(\n            tf.math.not_equal(y_true, y_pred), axis=1\n        )\n        num_errors = tf.cast(num_errors, tf.float32)\n        num_errors = tf.math.reduce_sum(num_errors)\n        batch_size = tf.cast(batch_size, tf.float32)\n        self.total.assign_add(batch_size)\n        self.count.assign_add(batch_size - num_errors)\n\n    def result(self):\n        return self.count / self.total\n\n    def reset_states(self):\n        self.count.assign(0)\n        self.total.assign(0)\n\n\nclass EditDistance(keras.metrics.Metric):\n    def __init__(self, name=\"edit_distance\", **kwargs):\n        super().__init__(name=name, **kwargs)\n        self.total = self.add_weight(name=\"total\", 
initializer=\"zeros\")\n        self.sum_distance = self.add_weight(\n            name=\"sum_distance\", initializer=\"zeros\"\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        y_pred_shape = tf.shape(y_pred)\n        batch_size = y_pred_shape[0]\n        logit_length = tf.fill([batch_size], y_pred_shape[1])\n        decoded, _ = tf.nn.ctc_greedy_decoder(\n            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),\n            sequence_length=logit_length,\n        )\n        sum_distance = tf.math.reduce_sum(tf.edit_distance(decoded[0], y_true))\n        batch_size = tf.cast(batch_size, tf.float32)\n        self.sum_distance.assign_add(sum_distance)\n        self.total.assign_add(batch_size)\n\n    def result(self):\n        return self.sum_distance / self.total\n\n    def reset_states(self):\n        self.sum_distance.assign(0)\n        self.total.assign(0)\n"
  },
  {
    "path": "crnn/models.py",
    "content": "from tensorflow import keras\nfrom tensorflow.keras import layers\n\n\ndef vgg_style(x):\n    \"\"\"\n    The original feature extraction structure from CRNN paper.\n    Related paper: https://ieeexplore.ieee.org/abstract/document/7801919\n    \"\"\"\n    x = layers.Conv2D(64, 3, padding=\"same\", activation=\"relu\", name=\"conv1\")(x)\n    x = layers.MaxPool2D(pool_size=2, padding=\"same\", name=\"pool1\")(x)\n\n    x = layers.Conv2D(128, 3, padding=\"same\", activation=\"relu\", name=\"conv2\")(\n        x\n    )\n    x = layers.MaxPool2D(pool_size=2, padding=\"same\", name=\"pool2\")(x)\n\n    x = layers.Conv2D(256, 3, padding=\"same\", use_bias=False, name=\"conv3\")(x)\n    x = layers.BatchNormalization(name=\"bn3\")(x)\n    x = layers.Activation(\"relu\", name=\"relu3\")(x)\n    x = layers.Conv2D(256, 3, padding=\"same\", activation=\"relu\", name=\"conv4\")(\n        x\n    )\n    x = layers.MaxPool2D(\n        pool_size=2, strides=(2, 1), padding=\"same\", name=\"pool4\"\n    )(x)\n\n    x = layers.Conv2D(512, 3, padding=\"same\", use_bias=False, name=\"conv5\")(x)\n    x = layers.BatchNormalization(name=\"bn5\")(x)\n    x = layers.Activation(\"relu\", name=\"relu5\")(x)\n    x = layers.Conv2D(512, 3, padding=\"same\", activation=\"relu\", name=\"conv6\")(\n        x\n    )\n    x = layers.MaxPool2D(\n        pool_size=2, strides=(2, 1), padding=\"same\", name=\"pool6\"\n    )(x)\n\n    x = layers.Conv2D(512, 2, use_bias=False, name=\"conv7\")(x)\n    x = layers.BatchNormalization(name=\"bn7\")(x)\n    x = layers.Activation(\"relu\", name=\"relu7\")(x)\n\n    x = layers.Reshape((-1, 512), name=\"reshape7\")(x)\n    return x\n\n\ndef build_model(\n    num_classes,\n    weight=None,\n    preprocess=None,\n    postprocess=None,\n    img_shape=(32, None, 3),\n    model_name=\"crnn\",\n):\n    x = img_input = keras.Input(shape=img_shape)\n    if preprocess is not None:\n        x = preprocess(x)\n\n    x = vgg_style(x)\n    x = 
layers.Bidirectional(\n        layers.LSTM(units=256, return_sequences=True), name=\"bi_lstm1\"\n    )(x)\n    x = layers.Bidirectional(\n        layers.LSTM(units=256, return_sequences=True), name=\"bi_lstm2\"\n    )(x)\n    x = layers.Dense(units=num_classes, name=\"logits\")(x)\n\n    if postprocess is not None:\n        x = postprocess(x)\n\n    model = keras.Model(inputs=img_input, outputs=x, name=model_name)\n    if weight is not None:\n        model.load_weights(weight, by_name=True, skip_mismatch=True)\n    return model\n"
  },
  {
    "path": "crnn/train.py",
    "content": "import argparse\nimport pprint\nimport shutil\nfrom pathlib import Path\n\nimport tensorflow as tf\nimport yaml\nfrom tensorflow import keras\n\nfrom dataset_factory import DatasetBuilder\nfrom losses import CTCLoss\nfrom metrics import SequenceAccuracy\nfrom models import build_model\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--config\", type=Path, required=True, help=\"The config file path.\"\n)\nparser.add_argument(\n    \"--save_dir\",\n    type=Path,\n    required=True,\n    help=\"The path to save the models, logs, etc.\",\n)\nargs = parser.parse_args()\n\nwith args.config.open() as f:\n    config = yaml.load(f, Loader=yaml.Loader)[\"train\"]\npprint.pprint(config)\n\nargs.save_dir.mkdir(exist_ok=True)\nif list(args.save_dir.iterdir()):\n    raise ValueError(f\"{args.save_dir} is not an empty folder\")\nshutil.copy(args.config, args.save_dir / args.config.name)\n\nstrategy = tf.distribute.MirroredStrategy()\nbatch_size = config[\"batch_size_per_replica\"] * strategy.num_replicas_in_sync\n\ndataset_builder = DatasetBuilder(**config[\"dataset_builder\"])\ntrain_ds = dataset_builder(config[\"train_ann_paths\"], batch_size, True)\nval_ds = dataset_builder(config[\"val_ann_paths\"], batch_size, False)\n\nwith strategy.scope():\n    lr_schedule = keras.optimizers.schedules.CosineDecay(\n        **config[\"lr_schedule\"]\n    )\n    model = build_model(\n        dataset_builder.num_classes,\n        weight=config.get(\"weight\"),\n        img_shape=config[\"dataset_builder\"][\"img_shape\"],\n    )\n    model.compile(\n        optimizer=keras.optimizers.Adam(lr_schedule),\n        loss=CTCLoss(),\n        metrics=[SequenceAccuracy()],\n    )\n\nmodel.summary()\n\nmodel_prefix = \"{epoch}_{val_loss:.4f}_{val_sequence_accuracy:.4f}\"\nmodel_path = f\"{args.save_dir}/{model_prefix}.h5\"\ncallbacks = [\n    keras.callbacks.ModelCheckpoint(model_path, save_weights_only=True),\n    keras.callbacks.TensorBoard(\n        log_dir=f\"{args.save_dir}/logs\", **config[\"tensorboard\"]\n    ),\n]\n\nmodel.fit(\n    train_ds,\n    epochs=config[\"epochs\"],\n    callbacks=callbacks,\n    validation_data=val_ds,\n)\n"
  },
  {
    "path": "example/icdar2013_annotation.txt",
    "content": "images/word_1.png, \"Tiredness\"\nimages/word_2.png, \"kills\"\nimages/word_3.png, \"A\""
  },
  {
    "path": "example/mjsynth_annotation.txt",
    "content": "images/1_Paintbrushes_55044.jpg 55044\nimages/2_Reimbursing_64165.jpg 64165\nimages/3_Creationisms_17934.jpg 17934"
  },
  {
    "path": "example/simple_annotation.txt",
    "content": "images/1_Paintbrushes_55044.jpg Paintbrushes\nimages/2_Reimbursing_64165.jpg Reimbursing\nimages/word_1.png Tiredness\nimages/word_2.png kills"
  },
  {
    "path": "example/table.txt",
    "content": "<UNK>\n0\n1\n2\n3\n4\n5\n6\n7\n8\n9\na\nb\nc\nd\ne\nf\ng\nh\ni\nj\nk\nl\nm\nn\no\np\nq\nr\ns\nt\nu\nv\nw\nx\ny\nz\n<BLK>"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\nline-length = 80\ntarget-version = ['py37']\ninclude = '\\.pyi?$'\n# 'extend-exclude' excludes files or directories in addition to the defaults\nextend-exclude = '''\n# A regex preceded with ^/ will apply only to files and directories\n# in the root of the project.\n'''\n\n[tool.isort]\nprofile = \"black\"\nsrc_paths = [\"crnn\"]\nline_length = 80\nforce_single_line = true\n"
  },
  {
    "path": "requirements.txt",
    "content": "pyyaml\ntensorflow >= 2.3"
  },
  {
    "path": "requirements_dev.txt",
    "content": "-r requirements.txt\n\npre-commit\n"
  },
  {
    "path": "tools/demo.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport tensorflow as tf\nimport yaml\nfrom tensorflow import keras\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--images\", type=str, required=True, help=\"Image file or folder path.\"\n)\nparser.add_argument(\n    \"--config\", type=Path, required=True, help=\"The config file path.\"\n)\nparser.add_argument(\"--model\", type=str, required=True, help=\"The saved model.\")\nargs = parser.parse_args()\n\nwith args.config.open() as f:\n    config = yaml.load(f, Loader=yaml.Loader)[\"dataset_builder\"]\n\n\ndef read_img_and_resize(path, shape):\n    img = tf.io.read_file(path)\n    img = tf.io.decode_jpeg(img, channels=shape[2])\n    if shape[1] is None:\n        img_shape = tf.shape(img)\n        scale_factor = shape[0] / img_shape[0]\n        img_width = scale_factor * tf.cast(img_shape[1], tf.float64)\n        img_width = tf.cast(img_width, tf.int32)\n    else:\n        img_width = shape[1]\n    img = tf.image.resize(img, (shape[0], img_width))\n    return img\n\n\nmodel = keras.models.load_model(args.model, compile=False)\n\np = Path(args.images)\nimg_paths = p.iterdir() if p.is_dir() else [p]\nfor img_path in img_paths:\n    img = read_img_and_resize(str(img_path), config[\"img_shape\"])\n    img = tf.expand_dims(img, 0)\n    outputs = model(img)\n    print(\n        f\"Path: {img_path}, y_pred: {outputs[0].numpy()}, \"\n        f\"probability: {outputs[1].numpy()}\"\n    )\n"
  }
]