Full Code of FLming/CRNN.tf2 for AI

Repository: FLming/CRNN.tf2
Branch: main
Commit: ccf2a7a7a21f
Files: 23
Total size: 27.4 KB

Directory structure:
CRNN.tf2/

├── .dockerignore
├── .gitignore
├── .pre-commit-config.yaml
├── Changelog
├── Dockerfile
├── LICENSE
├── README.md
├── crnn/
│   ├── dataset_factory.py
│   ├── decoders.py
│   ├── eval.py
│   ├── export.py
│   ├── losses.py
│   ├── metrics.py
│   ├── models.py
│   └── train.py
├── example/
│   ├── icdar2013_annotation.txt
│   ├── mjsynth_annotation.txt
│   ├── simple_annotation.txt
│   └── table.txt
├── pyproject.toml
├── requirements.txt
├── requirements_dev.txt
└── tools/
    └── demo.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .dockerignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/

.vscode/
models/


================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/

.vscode/
configs/
models/


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
    - repo: https://github.com/pre-commit/pre-commit-hooks
      rev: v2.3.0
      hooks:
          - id: check-yaml
          - id: end-of-file-fixer
          - id: trailing-whitespace
    - repo: https://github.com/psf/black
      rev: 22.10.0
      hooks:
          - id: black
    - repo: https://github.com/pycqa/isort
      rev: 5.12.0
      hooks:
          - id: isort


================================================
FILE: Changelog
================================================
# Changelog
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Use RaggedTensor for train and eval pipeline

## [0.2.0] - 2021-07-29
### Changed
- New export script
- New way to build the model

## [0.1.1] - 2021-01-22
### Changed
- Reduce minimum TensorFlow version to 2.2
- Better way to add post-processing
- Use StaticHashTable instead of StringLookup layer

## [0.1.0] - 2021-01-09
### Added
- Add <UNK> label
- Add docker support
- Add EditDistance metrics
- Add Decoders for a truly end-to-end model [experimental]

### Changed
- Update minimum TensorFlow version to 2.3.0
- Change img_height, img_width, img_channels to img_shape
- Build a new data pipeline
- Use new preprocessing
- Change model_dir to save_dir in train.py


================================================
FILE: Dockerfile
================================================
FROM tensorflow/tensorflow:latest-gpu

WORKDIR /workspace

COPY . /workspace

RUN pip install --no-cache-dir -r requirements.txt


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019-2021 Huang Yiming

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Convolutional Recurrent Neural Network for End-to-End Text Recognition - TensorFlow 2

![TensorFlow version](https://img.shields.io/badge/TensorFlow->=2.3-FF6F00?logo=tensorflow)
![Python version](https://img.shields.io/badge/Python->=3.6-3776AB?logo=python)
[![Paper](https://img.shields.io/badge/paper-arXiv:1507.05717-B3181B?logo=arXiv)](https://arxiv.org/abs/1507.05717)
[![Zhihu](https://img.shields.io/badge/知乎-文本识别网络CRNN—实现简述-blue?logo=zhihu)](https://zhuanlan.zhihu.com/p/122512498)

This is a re-implementation of the CRNN network, built with TensorFlow 2. This repository may help you understand how to build an end-to-end text recognition network easily. Here is the official [repo](https://github.com/bgshih/crnn) implemented by [bgshih](https://github.com/bgshih).

## Abstract

This repo aims to build a simple, efficient text recognition network using the various components of TensorFlow 2. The model is built with the Keras API, the data pipeline with `tf.data`, and training is done with `model.fit`, so we can use most of the features provided by TensorFlow 2, such as `Tensorboard`, `Distribution strategy`, `TensorFlow Profiler`, etc.

## Installation

```bash
$ pip install -r requirements.txt
```

## Demo

Here is an example model trained on the MJSynth dataset; it can only predict 0-9 and a-z (case-insensitive).

```bash
$ wget https://github.com/FLming/CRNN.tf2/releases/download/v0.2.0/SavedModel.tgz
$ tar xzvf SavedModel.tgz
$ python tools/demo.py --images example/images/ --config configs/mjsynth.yml --model SavedModel
```

Then you will see output like this:
```
Path: example/images/word_1.png, y_pred: [b'tiredness'], probability: [0.9998626]
Path: example/images/word_3.png, y_pred: [b'a'], probability: [0.67493004]
Path: example/images/2_Reimbursing_64165.jpg, y_pred: [b'reimbursing'], probability: [0.990946]
Path: example/images/word_2.png, y_pred: [b'kills'], probability: [0.9994573]
Path: example/images/1_Paintbrushes_55044.jpg, y_pred: [b'paintbrushes'], probability: [0.9984008]
Path: example/images/3_Creationisms_17934.jpg, y_pred: [b'creationisms'], probability: [0.99792457]
```

As for decoding methods, beam search can sometimes outperform greedy decoding, but it is more costly.
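
To make the greedy method concrete, here is a minimal pure-Python sketch (TensorFlow-free, for illustration only) of the collapse rule that greedy CTC decoding applies: take the argmax class at each time step, merge consecutive repeats, then drop the blank symbol (the last index, matching `table.txt`):

```python
def ctc_greedy_collapse(argmax_ids, blank_id):
    """Collapse a per-time-step argmax sequence the way greedy CTC decoding does."""
    out = []
    prev = None
    for t in argmax_ids:
        # keep a symbol only if it differs from the previous step and is not blank
        if t != prev and t != blank_id:
            out.append(t)
        prev = t
    return out

# A repeated character survives only if a blank separates the repeats:
ids = [7, 7, 4, 11, 26, 11, 14]  # hypothetical argmax ids, blank id 26
print(ctc_greedy_collapse(ids, blank_id=26))  # [7, 4, 11, 11, 14]
```

Beam search instead keeps the `beam_width` most probable label prefixes at each step, which is why it can recover from a locally wrong argmax at a higher cost.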

## Train

Before you start training, you should [prepare](#Data-prepare) the data first. All predictable characters are defined in the [table.txt](example/table.txt) file. The training configuration is defined in the [yml](configs/mjsynth.yml) file.

The training script uses all GPUs by default; if you want to use specific GPUs, set the `CUDA_VISIBLE_DEVICES` environment variable.

```bash
$ python crnn/train.py --config configs/mjsynth.yml --save_dir PATH/TO/SAVE
```

The training process can be visualized in Tensorboard.

```bash
$ tensorboard --logdir PATH/TO/MODEL_DIR
```

For more instructions, please refer to the [config](configs/mjsynth.yml) file.

## Data prepare

To train this network, you should prepare a lookup table, images, and corresponding labels. The example data is copied from the [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [ICDAR2013](https://rrc.cvc.uab.es/?ch=2&com=introduction) datasets.

### [Lookup table](./example/table.txt)

The file contains all characters plus the blank label (it can go last or anywhere else, but the TensorFlow decoders currently cannot change the blank index, so put it last). Any token can serve as the blank label.
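
As a sketch of what the `StaticHashTable` in `crnn/dataset_factory.py` builds from this file, here is a pure-Python illustration (using a shortened, hypothetical table) of mapping each line to its 0-based line number, with `<UNK>` at index 0 serving as the fallback for unknown characters:

```python
# Shortened stand-in for table.txt: one token per line, blank label last.
table_lines = ["<UNK>", "0", "1", "a", "b", "<BLK>"]
char_to_id = {ch: i for i, ch in enumerate(table_lines)}

def tokenize(label, table):
    # Characters not in the table fall back to id 0 (<UNK>),
    # mirroring the hash table's default value.
    return [table.get(ch, 0) for ch in label]

print(tokenize("ab?", char_to_id))  # [3, 4, 0]
```

The blank label sitting at the last index is what lets the loss and decoders use `blank_index=-1` throughout the code.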

### Image data

It is an end-to-end method, so we do not need to indicate the position of each character in the image.

![Paintbrushes](example/images/1_Paintbrushes_55044.jpg)
![Creationisms](example/images/3_Creationisms_17934.jpg)
![Reimbursing](example/images/2_Reimbursing_64165.jpg)

The labels corresponding to these three pictures are `Paintbrushes`, `Creationisms`, `Reimbursing`.

### Annotation file

Write the image path and its corresponding label to a text file in one of the supported formats, as in the example data. The data input pipeline will automatically detect the format. Customization is also very simple; check out the [dataset factory](crnn/dataset_factory.py).

#### Supported formats

- [MJSynth](./example/mjsynth_annotation.txt)
- [ICDAR2013/2015](./example/icdar2013_annotation.txt)
- [Simple](./example/simple_annotation.txt), e.g. `example.jpg label`
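
The auto-detection works by probing the first line of the annotation file against one regex per format; here is a pure-Python sketch using the same regexes as `_parse_annotation` in `crnn/dataset_factory.py`:

```python
import re

def detect_format(first_line):
    """Classify an annotation file by its first line, as the dataset factory does."""
    # MJSynth: "<index>_<Label>_<id>.<ext> <id>" (the id repeats after the space)
    if re.fullmatch(r".*/*\d+_.+_(\d+)\.\w+ \1", first_line):
        return "mjsynth"
    # ICDAR: "word_<n>.<ext>, \"<Label>\""
    if re.fullmatch(r'.*/*word_\d\.\w+, ".+"', first_line):
        return "icdar"
    # Simple: "<path>.<ext> <label>"
    if re.fullmatch(r".+\.\w+ .+", first_line):
        return "simple"
    raise ValueError("Unsupported annotation format")

print(detect_format('images/1_Paintbrushes_55044.jpg 55044'))  # mjsynth
print(detect_format('images/word_1.png, "Tiredness"'))         # icdar
print(detect_format("images/word_1.png Tiredness"))            # simple
```

The probe order matters: the MJSynth pattern is tried first, and its backreference (`\1`) keeps a simple-format line like `images/word_1.png Tiredness` from matching it.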

## Eval

```bash
$ python crnn/eval.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT
```

## Convert & Ecosystem

There are several components here to help with other tasks, for example deploying with `Tensorflow serving`. Before you deploy, pick a good weight file and convert the model to the `SavedModel` format with this command; it appends the post-processing layer and strips the optimizer:

```bash
$ python crnn/export.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT --pre rescale --post greedy --output PATH/TO/OUTPUT
```

`Tensorflow lite` can now also convert this model, which means you can deploy it to Android, iOS, etc.

Note: the decoders cannot be converted to `Tensorflow lite` because of their assets. Use the softmax layer or no post-processing instead.


================================================
FILE: crnn/dataset_factory.py
================================================
import os
import re

import tensorflow as tf

try:
    AUTOTUNE = tf.data.AUTOTUNE
except AttributeError:
    # tf < 2.4.0
    AUTOTUNE = tf.data.experimental.AUTOTUNE


class Dataset(tf.data.TextLineDataset):
    def __init__(self, filename, **kwargs):
        self.dirname = os.path.dirname(filename)
        super().__init__(filename, **kwargs)

    def parse_func(self, line):
        raise NotImplementedError

    def parse_line(self, line):
        line = tf.strings.strip(line)
        img_relative_path, label = self.parse_func(line)
        img_path = tf.strings.join([self.dirname, os.sep, img_relative_path])
        return img_path, label


class SimpleDataset(Dataset):
    def parse_func(self, line):
        splited_line = tf.strings.split(line)
        img_relative_path, label = splited_line[0], splited_line[1]
        return img_relative_path, label


class MJSynthDataset(Dataset):
    def parse_func(self, line):
        splited_line = tf.strings.split(line)
        img_relative_path = splited_line[0]
        label = tf.strings.split(img_relative_path, sep="_")[1]
        return img_relative_path, label


class ICDARDataset(Dataset):
    def parse_func(self, line):
        splited_line = tf.strings.split(line, sep=",")
        img_relative_path, label = splited_line[0], splited_line[1]
        label = tf.strings.strip(label)
        label = tf.strings.regex_replace(label, r'"', "")
        return img_relative_path, label


class DatasetBuilder:
    def __init__(
        self,
        table_path,
        img_shape=(32, None, 3),
        max_img_width=300,
        ignore_case=False,
    ):
        # map unknown label to 0
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.TextFileInitializer(
                table_path,
                tf.string,
                tf.lookup.TextFileIndex.WHOLE_LINE,
                tf.int64,
                tf.lookup.TextFileIndex.LINE_NUMBER,
            ),
            0,
        )
        self.img_shape = img_shape
        self.ignore_case = ignore_case
        if img_shape[1] is None:
            self.max_img_width = max_img_width
            self.preserve_aspect_ratio = True
        else:
            self.preserve_aspect_ratio = False

    @property
    def num_classes(self):
        return self.table.size()

    def _parse_annotation(self, path):
        with open(path) as f:
            line = f.readline().strip()
        if re.fullmatch(r".*/*\d+_.+_(\d+)\.\w+ \1", line):
            return MJSynthDataset(path)
        elif re.fullmatch(r'.*/*word_\d\.\w+, ".+"', line):
            return ICDARDataset(path)
        elif re.fullmatch(r".+\.\w+ .+", line):
            return SimpleDataset(path)
        else:
            raise ValueError("Unsupported annotation format")

    def _concatenate_ds(self, ann_paths):
        datasets = [self._parse_annotation(path) for path in ann_paths]
        concatenated_ds = datasets[0].map(datasets[0].parse_line)
        for ds in datasets[1:]:
            ds = ds.map(ds.parse_line)
            concatenated_ds = concatenated_ds.concatenate(ds)
        return concatenated_ds

    def _decode_img(self, filename, label):
        img = tf.io.read_file(filename)
        img = tf.io.decode_jpeg(img, channels=self.img_shape[-1])
        if self.preserve_aspect_ratio:
            img_shape = tf.shape(img)
            scale_factor = self.img_shape[0] / img_shape[0]
            img_width = scale_factor * tf.cast(img_shape[1], tf.float64)
            img_width = tf.cast(img_width, tf.int32)
        else:
            img_width = self.img_shape[1]
        img = tf.image.resize(img, (self.img_shape[0], img_width)) / 255.0
        return img, label

    def _filter_img(self, img, label):
        img_shape = tf.shape(img)
        return img_shape[1] < self.max_img_width

    def _tokenize(self, imgs, labels):
        chars = tf.strings.unicode_split(labels, "UTF-8")
        tokens = tf.ragged.map_flat_values(self.table.lookup, chars)
        # TODO(hym) Waiting for official support to use RaggedTensor in keras
        tokens = tokens.to_sparse()
        return imgs, tokens

    def __call__(self, ann_paths, batch_size, is_training):
        ds = self._concatenate_ds(ann_paths)
        if self.ignore_case:
            ds = ds.map(lambda x, y: (x, tf.strings.lower(y)))
        if is_training:
            ds = ds.shuffle(buffer_size=10000)
        ds = ds.map(self._decode_img, AUTOTUNE)
        if self.preserve_aspect_ratio and batch_size != 1:
            ds = ds.filter(self._filter_img)
            ds = ds.padded_batch(batch_size, drop_remainder=is_training)
        else:
            ds = ds.batch(batch_size, drop_remainder=is_training)
        ds = ds.map(self._tokenize, AUTOTUNE)
        ds = ds.prefetch(AUTOTUNE)
        return ds


================================================
FILE: crnn/decoders.py
================================================
import tensorflow as tf
from tensorflow import keras


class CTCDecoder(keras.layers.Layer):
    def __init__(self, table_path, **kwargs):
        super().__init__(**kwargs)
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.TextFileInitializer(
                table_path,
                tf.int64,
                tf.lookup.TextFileIndex.LINE_NUMBER,
                tf.string,
                tf.lookup.TextFileIndex.WHOLE_LINE,
            ),
            "",
        )

    def detokenize(self, x):
        x = tf.RaggedTensor.from_sparse(x)
        x = tf.ragged.map_flat_values(self.table.lookup, x)
        strings = tf.strings.reduce_join(x, axis=1)
        return strings


class CTCGreedyDecoder(CTCDecoder):
    def __init__(self, table_path, merge_repeated=True, **kwargs):
        super().__init__(table_path, **kwargs)
        self.merge_repeated = merge_repeated

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        sequence_length = tf.fill([input_shape[0]], input_shape[1])
        decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(
            tf.transpose(inputs, perm=[1, 0, 2]),
            sequence_length,
            self.merge_repeated,
        )
        strings = self.detokenize(decoded[0])
        labels = tf.cast(decoded[0], tf.int32)
        loss = tf.nn.ctc_loss(
            labels=labels,
            logits=inputs,
            label_length=None,
            logit_length=sequence_length,
            logits_time_major=False,
            blank_index=-1,
        )
        probability = tf.math.exp(-loss)
        return strings, probability


class CTCBeamSearchDecoder(CTCDecoder):
    def __init__(self, table_path, beam_width=100, top_paths=1, **kwargs):
        super().__init__(table_path, **kwargs)
        self.beam_width = beam_width
        self.top_paths = top_paths

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        decoded, log_probability = tf.nn.ctc_beam_search_decoder(
            tf.transpose(inputs, perm=[1, 0, 2]),
            tf.fill([input_shape[0]], input_shape[1]),
            self.beam_width,
            self.top_paths,
        )
        strings = []
        for i in range(self.top_paths):
            strings.append(self.detokenize(decoded[i]))
        strings = tf.concat(strings, 1)
        probability = tf.math.exp(log_probability)
        return strings, probability


================================================
FILE: crnn/eval.py
================================================
import argparse
import pprint

import yaml

from dataset_factory import DatasetBuilder
from losses import CTCLoss
from metrics import EditDistance
from metrics import SequenceAccuracy
from models import build_model

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", type=str, required=True, help="The config file path."
)
parser.add_argument(
    "--weight", type=str, required=True, help="The saved weight path."
)
args = parser.parse_args()

with open(args.config) as f:
    config = yaml.load(f, Loader=yaml.Loader)["eval"]
pprint.pprint(config)

dataset_builder = DatasetBuilder(**config["dataset_builder"])
ds = dataset_builder(config["ann_paths"], config["batch_size"], False)
model = build_model(
    dataset_builder.num_classes,
    weight=args.weight,
    img_shape=config["dataset_builder"]["img_shape"],
)
model.compile(loss=CTCLoss(), metrics=[SequenceAccuracy(), EditDistance()])
model.evaluate(ds)


================================================
FILE: crnn/export.py
================================================
import argparse
from pathlib import Path

import yaml
from tensorflow import keras

from decoders import CTCBeamSearchDecoder
from decoders import CTCGreedyDecoder
from models import build_model

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", type=Path, required=True, help="The config file path."
)
parser.add_argument(
    "--weight",
    type=str,
    required=True,
    default="",
    help="The saved weight path.",
)
parser.add_argument("--pre", type=str, help="pre processing.")
parser.add_argument("--post", type=str, help="Post processing.")
parser.add_argument(
    "--output", type=str, required=True, help="The output path."
)
args = parser.parse_args()

with args.config.open() as f:
    config = yaml.load(f, Loader=yaml.Loader)["dataset_builder"]

with open(config["table_path"]) as f:
    num_classes = len(f.readlines())

if args.pre == "rescale":
    preprocess = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)
else:
    preprocess = None

if args.post == "softmax":
    postprocess = keras.layers.Softmax()
elif args.post == "greedy":
    postprocess = CTCGreedyDecoder(config["table_path"])
elif args.post == "beam_search":
    postprocess = CTCBeamSearchDecoder(config["table_path"])
else:
    postprocess = None

model = build_model(
    num_classes,
    weight=args.weight,
    preprocess=preprocess,
    postprocess=postprocess,
    img_shape=config["img_shape"],
)
model.summary()
model.save(args.output)


================================================
FILE: crnn/losses.py
================================================
import tensorflow as tf
from tensorflow import keras


class CTCLoss(keras.losses.Loss):
    """A class that wraps the function of tf.nn.ctc_loss.

    Attributes:
        logits_time_major: If False (default) , shape is [batch, time, logits],
            If True, logits is shaped [time, batch, logits].
        blank_index: Set the class index to use for the blank label. default is
            -1 (num_classes - 1).
    """

    def __init__(
        self, logits_time_major=False, blank_index=-1, name="ctc_loss"
    ):
        super().__init__(name=name)
        self.logits_time_major = logits_time_major
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """Computes CTC (Connectionist Temporal Classification) loss. work on
        CPU, because y_true is a SparseTensor.
        """
        y_true = tf.cast(y_true, tf.int32)
        y_pred_shape = tf.shape(y_pred)
        logit_length = tf.fill([y_pred_shape[0]], y_pred_shape[1])
        loss = tf.nn.ctc_loss(
            labels=y_true,
            logits=y_pred,
            label_length=None,
            logit_length=logit_length,
            logits_time_major=self.logits_time_major,
            blank_index=self.blank_index,
        )
        return tf.math.reduce_mean(loss)


================================================
FILE: crnn/metrics.py
================================================
import tensorflow as tf
from tensorflow import keras


class SequenceAccuracy(keras.metrics.Metric):
    def __init__(self, name="sequence_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        def sparse2dense(tensor, shape):
            tensor = tf.sparse.reset_shape(tensor, shape)
            tensor = tf.sparse.to_dense(tensor, default_value=-1)
            tensor = tf.cast(tensor, tf.float32)
            return tensor

        y_true_shape = tf.shape(y_true)
        batch_size = y_true_shape[0]
        y_pred_shape = tf.shape(y_pred)
        max_width = tf.math.maximum(y_true_shape[1], y_pred_shape[1])
        logit_length = tf.fill([batch_size], y_pred_shape[1])
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length,
        )
        y_true = sparse2dense(y_true, [batch_size, max_width])
        y_pred = sparse2dense(decoded[0], [batch_size, max_width])
        num_errors = tf.math.reduce_any(
            tf.math.not_equal(y_true, y_pred), axis=1
        )
        num_errors = tf.cast(num_errors, tf.float32)
        num_errors = tf.math.reduce_sum(num_errors)
        batch_size = tf.cast(batch_size, tf.float32)
        self.total.assign_add(batch_size)
        self.count.assign_add(batch_size - num_errors)

    def result(self):
        return self.count / self.total

    def reset_states(self):
        self.count.assign(0)
        self.total.assign(0)


class EditDistance(keras.metrics.Metric):
    def __init__(self, name="edit_distance", **kwargs):
        super().__init__(name=name, **kwargs)
        self.total = self.add_weight(name="total", initializer="zeros")
        self.sum_distance = self.add_weight(
            name="sum_distance", initializer="zeros"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred_shape = tf.shape(y_pred)
        batch_size = y_pred_shape[0]
        logit_length = tf.fill([batch_size], y_pred_shape[1])
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length,
        )
        sum_distance = tf.math.reduce_sum(tf.edit_distance(decoded[0], y_true))
        batch_size = tf.cast(batch_size, tf.float32)
        self.sum_distance.assign_add(sum_distance)
        self.total.assign_add(batch_size)

    def result(self):
        return self.sum_distance / self.total

    def reset_states(self):
        self.sum_distance.assign(0)
        self.total.assign(0)


================================================
FILE: crnn/models.py
================================================
from tensorflow import keras
from tensorflow.keras import layers


def vgg_style(x):
    """
    The original feature extraction structure from CRNN paper.
    Related paper: https://ieeexplore.ieee.org/abstract/document/7801919
    """
    x = layers.Conv2D(64, 3, padding="same", activation="relu", name="conv1")(x)
    x = layers.MaxPool2D(pool_size=2, padding="same", name="pool1")(x)

    x = layers.Conv2D(128, 3, padding="same", activation="relu", name="conv2")(
        x
    )
    x = layers.MaxPool2D(pool_size=2, padding="same", name="pool2")(x)

    x = layers.Conv2D(256, 3, padding="same", use_bias=False, name="conv3")(x)
    x = layers.BatchNormalization(name="bn3")(x)
    x = layers.Activation("relu", name="relu3")(x)
    x = layers.Conv2D(256, 3, padding="same", activation="relu", name="conv4")(
        x
    )
    x = layers.MaxPool2D(
        pool_size=2, strides=(2, 1), padding="same", name="pool4"
    )(x)

    x = layers.Conv2D(512, 3, padding="same", use_bias=False, name="conv5")(x)
    x = layers.BatchNormalization(name="bn5")(x)
    x = layers.Activation("relu", name="relu5")(x)
    x = layers.Conv2D(512, 3, padding="same", activation="relu", name="conv6")(
        x
    )
    x = layers.MaxPool2D(
        pool_size=2, strides=(2, 1), padding="same", name="pool6"
    )(x)

    x = layers.Conv2D(512, 2, use_bias=False, name="conv7")(x)
    x = layers.BatchNormalization(name="bn7")(x)
    x = layers.Activation("relu", name="relu7")(x)

    x = layers.Reshape((-1, 512), name="reshape7")(x)
    return x


def build_model(
    num_classes,
    weight=None,
    preprocess=None,
    postprocess=None,
    img_shape=(32, None, 3),
    model_name="crnn",
):
    x = img_input = keras.Input(shape=img_shape)
    if preprocess is not None:
        x = preprocess(x)

    x = vgg_style(x)
    x = layers.Bidirectional(
        layers.LSTM(units=256, return_sequences=True), name="bi_lstm1"
    )(x)
    x = layers.Bidirectional(
        layers.LSTM(units=256, return_sequences=True), name="bi_lstm2"
    )(x)
    x = layers.Dense(units=num_classes, name="logits")(x)

    if postprocess is not None:
        x = postprocess(x)

    model = keras.Model(inputs=img_input, outputs=x, name=model_name)
    if weight is not None:
        model.load_weights(weight, by_name=True, skip_mismatch=True)
    return model


================================================
FILE: crnn/train.py
================================================
import argparse
import pprint
import shutil
from pathlib import Path

import tensorflow as tf
import yaml
from tensorflow import keras

from dataset_factory import DatasetBuilder
from losses import CTCLoss
from metrics import SequenceAccuracy
from models import build_model

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", type=Path, required=True, help="The config file path."
)
parser.add_argument(
    "--save_dir",
    type=Path,
    required=True,
    help="The path to save the models, logs, etc.",
)
args = parser.parse_args()

with args.config.open() as f:
    config = yaml.load(f, Loader=yaml.Loader)["train"]
pprint.pprint(config)

args.save_dir.mkdir(exist_ok=True)
if list(args.save_dir.iterdir()):
    raise ValueError(f"{args.save_dir} is not a empty folder")
shutil.copy(args.config, args.save_dir / args.config.name)

strategy = tf.distribute.MirroredStrategy()
batch_size = config["batch_size_per_replica"] * strategy.num_replicas_in_sync

dataset_builder = DatasetBuilder(**config["dataset_builder"])
train_ds = dataset_builder(config["train_ann_paths"], batch_size, True)
val_ds = dataset_builder(config["val_ann_paths"], batch_size, False)

with strategy.scope():
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        **config["lr_schedule"]
    )
    model = build_model(
        dataset_builder.num_classes,
        weight=config.get("weight"),
        img_shape=config["dataset_builder"]["img_shape"],
    )
    model.compile(
        optimizer=keras.optimizers.Adam(lr_schedule),
        loss=CTCLoss(),
        metrics=[SequenceAccuracy()],
    )

model.summary()

model_prefix = "{epoch}_{val_loss:.4f}_{val_sequence_accuracy:.4f}"
model_path = f"{args.save_dir}/{model_prefix}.h5"
callbacks = [
    keras.callbacks.ModelCheckpoint(model_path, save_weights_only=True),
    keras.callbacks.TensorBoard(
        log_dir=f"{args.save_dir}/logs", **config["tensorboard"]
    ),
]

model.fit(
    train_ds,
    epochs=config["epochs"],
    callbacks=callbacks,
    validation_data=val_ds,
)


================================================
FILE: example/icdar2013_annotation.txt
================================================
images/word_1.png, "Tiredness"
images/word_2.png, "kills"
images/word_3.png, "A"

================================================
FILE: example/mjsynth_annotation.txt
================================================
images/1_Paintbrushes_55044.jpg 55044
images/2_Reimbursing_64165.jpg 64165
images/3_Creationisms_17934.jpg 17934

================================================
FILE: example/simple_annotation.txt
================================================
images/1_Paintbrushes_55044.jpg Paintbrushes
images/2_Reimbursing_64165.jpg Reimbursing
images/word_1.png Tiredness
images/word_2.png kills

================================================
FILE: example/table.txt
================================================
<UNK>
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
<BLK>

================================================
FILE: pyproject.toml
================================================
[tool.black]
line-length = 80
target-version = ['py37']
include = '\.pyi?$'
# 'extend-exclude' excludes files or directories in addition to the defaults
extend-exclude = '''
# A regex preceded with ^/ will apply only to files and directories
# in the root of the project.
'''

[tool.isort]
profile = "black"
src_paths = ["crnn"]
line_length = 80
force_single_line = true


================================================
FILE: requirements.txt
================================================
pyyaml
tensorflow >= 2.3

================================================
FILE: requirements_dev.txt
================================================
-r requirements.txt

pre-commit


================================================
FILE: tools/demo.py
================================================
import argparse
from pathlib import Path

import tensorflow as tf
import yaml
from tensorflow import keras

parser = argparse.ArgumentParser()
parser.add_argument(
    "--images", type=str, required=True, help="Image file or folder path."
)
parser.add_argument(
    "--config", type=Path, required=True, help="The config file path."
)
parser.add_argument("--model", type=str, required=True, help="The saved model.")
args = parser.parse_args()

with args.config.open() as f:
    config = yaml.load(f, Loader=yaml.Loader)["dataset_builder"]


def read_img_and_resize(path, shape):
    img = tf.io.read_file(path)
    img = tf.io.decode_jpeg(img, channels=shape[2])
    if shape[1] is None:
        img_shape = tf.shape(img)
        scale_factor = shape[0] / img_shape[0]
        img_width = scale_factor * tf.cast(img_shape[1], tf.float64)
        img_width = tf.cast(img_width, tf.int32)
    else:
        img_width = shape[1]
    img = tf.image.resize(img, (shape[0], img_width))
    return img


model = keras.models.load_model(args.model, compile=False)

p = Path(args.images)
img_paths = p.iterdir() if p.is_dir() else [p]
for img_path in img_paths:
    img = read_img_and_resize(str(img_path), config["img_shape"])
    img = tf.expand_dims(img, 0)
    outputs = model(img)
    print(
        f"Path: {img_path}, y_pred: {outputs[0].numpy()}, "
        f"probability: {outputs[1].numpy()}"
    )
================================================
SYMBOL INDEX (44 symbols across 6 files)
================================================

FILE: crnn/dataset_factory.py
  class Dataset (line 13) | class Dataset(tf.data.TextLineDataset):
    method __init__ (line 14) | def __init__(self, filename, **kwargs):
    method parse_func (line 18) | def parse_func(self, line):
    method parse_line (line 21) | def parse_line(self, line):
  class SimpleDataset (line 28) | class SimpleDataset(Dataset):
    method parse_func (line 29) | def parse_func(self, line):
  class MJSynthDataset (line 35) | class MJSynthDataset(Dataset):
    method parse_func (line 36) | def parse_func(self, line):
  class ICDARDataset (line 43) | class ICDARDataset(Dataset):
    method parse_func (line 44) | def parse_func(self, line):
  class DatasetBuilder (line 52) | class DatasetBuilder:
    method __init__ (line 53) | def __init__(
    method num_classes (line 80) | def num_classes(self):
    method _parse_annotation (line 83) | def _parse_annotation(self, path):
    method _concatenate_ds (line 95) | def _concatenate_ds(self, ann_paths):
    method _decode_img (line 103) | def _decode_img(self, filename, label):
    method _filter_img (line 116) | def _filter_img(self, img, label):
    method _tokenize (line 120) | def _tokenize(self, imgs, labels):
    method __call__ (line 127) | def __call__(self, ann_paths, batch_size, is_training):

FILE: crnn/decoders.py
  class CTCDecoder (line 5) | class CTCDecoder(keras.layers.Layer):
    method __init__ (line 6) | def __init__(self, table_path, **kwargs):
    method detokenize (line 19) | def detokenize(self, x):
  class CTCGreedyDecoder (line 26) | class CTCGreedyDecoder(CTCDecoder):
    method __init__ (line 27) | def __init__(self, table_path, merge_repeated=True, **kwargs):
    method call (line 31) | def call(self, inputs):
  class CTCBeamSearchDecoder (line 53) | class CTCBeamSearchDecoder(CTCDecoder):
    method __init__ (line 54) | def __init__(self, table_path, beam_width=100, top_paths=1, **kwargs):
    method call (line 59) | def call(self, inputs):

FILE: crnn/losses.py
  class CTCLoss (line 5) | class CTCLoss(keras.losses.Loss):
    method __init__ (line 15) | def __init__(
    method call (line 22) | def call(self, y_true, y_pred):

FILE: crnn/metrics.py
  class SequenceAccuracy (line 5) | class SequenceAccuracy(keras.metrics.Metric):
    method __init__ (line 6) | def __init__(self, name="sequence_accuracy", **kwargs):
    method update_state (line 11) | def update_state(self, y_true, y_pred, sample_weight=None):
    method result (line 38) | def result(self):
    method reset_states (line 41) | def reset_states(self):
  class EditDistance (line 46) | class EditDistance(keras.metrics.Metric):
    method __init__ (line 47) | def __init__(self, name="edit_distance", **kwargs):
    method update_state (line 54) | def update_state(self, y_true, y_pred, sample_weight=None):
    method result (line 67) | def result(self):
    method reset_states (line 70) | def reset_states(self):

FILE: crnn/models.py
  function vgg_style (line 5) | def vgg_style(x):
  function build_model (line 46) | def build_model(

FILE: tools/demo.py
  function read_img_and_resize (line 22) | def read_img_and_resize(path, shape):

About this extraction

This page contains the full source code of the FLming/CRNN.tf2 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 23 files (27.4 KB, approximately 7.6k tokens) and includes a symbol index of 44 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.