Repository: FLming/CRNN.tf2 Branch: main Commit: ccf2a7a7a21f Files: 23 Total size: 27.4 KB Directory structure: gitextract_9oovgxx4/ ├── .dockerignore ├── .gitignore ├── .pre-commit-config.yaml ├── Changelog ├── Dockerfile ├── LICENSE ├── README.md ├── crnn/ │ ├── dataset_factory.py │ ├── decoders.py │ ├── eval.py │ ├── export.py │ ├── losses.py │ ├── metrics.py │ ├── models.py │ └── train.py ├── example/ │ ├── icdar2013_annotation.txt │ ├── mjsynth_annotation.txt │ ├── simple_annotation.txt │ └── table.txt ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt └── tools/ └── demo.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ .vscode/ models/ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ .vscode/ configs/ models/ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/psf/black rev: 22.10.0 hooks: - id: black - repo: https://github.com/pycqa/isort rev: 5.12.0 hooks: - id: isort ================================================ FILE: Changelog ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
## [Unreleased]

### Added

- Use RaggedTensor for train and eval pipeline

## [0.2.0] - 2021-07-29

### Changed

- New export script
- New model build way

## [0.1.1] - 2021-01-22

### Changed

- Reduce minimum TensorFlow version to 2.2
- Better way to add post processing
- Use StaticHashTable instead of StringLookup layer

## [0.1.0] - 2021-01-09

### Added

- Add label
- Add docker support
- Add EditDistance metrics
- Add Decoders for a truly end-to-end model [experimental]

### Changed

- Update minimum TensorFlow version to 2.3.0
- Change img_height, img_width, img_channels to img_shape
- Build a new data pipeline
- Use new preprocessing
- Change model_dir to save_dir in train.py

================================================
FILE: Dockerfile
================================================
FROM tensorflow/tensorflow:latest-gpu

WORKDIR /workspace
COPY . /workspace

RUN pip install --no-cache-dir -r requirements.txt

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019-2021 Huang Yiming

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# Convolutional Recurrent Neural Network for End-to-End Text Recognition - TensorFlow 2

![TensorFlow version](https://img.shields.io/badge/TensorFlow->=2.3-FF6F00?logo=tensorflow)
![Python version](https://img.shields.io/badge/Python->=3.6-3776AB?logo=python)
[![Paper](https://img.shields.io/badge/paper-arXiv:1507.05717-B3181B?logo=arXiv)](https://arxiv.org/abs/1507.05717)
[![Zhihu](https://img.shields.io/badge/知乎-文本识别网络CRNN—实现简述-blue?logo=zhihu)](https://zhuanlan.zhihu.com/p/122512498)

This is a re-implementation of the CRNN network, built with TensorFlow 2. This repository may help you understand how to build an end-to-end text recognition network easily. Here is the official [repo](https://github.com/bgshih/crnn) implemented by [bgshih](https://github.com/bgshih).

## Abstract

This repo aims to build a simple, efficient text recognition network using the various components of TensorFlow 2. The model is built with the Keras API, the data pipeline is built with `tf.data`, and training is done with `model.fit`, so we can use most of the features provided by TensorFlow 2, such as `Tensorboard`, `Distribution strategy`, the `TensorFlow Profiler`, etc.

## Installation

```bash
$ pip install -r requirements.txt
```

## Demo

Here is an example model trained on the MJSynth dataset; it can only predict 0-9 and a-z (case-insensitive).
```bash
$ wget https://github.com/FLming/CRNN.tf2/releases/download/v0.2.0/SavedModel.tgz
$ tar xzvf SavedModel.tgz
$ python tools/demo.py --images example/images/ --config configs/mjsynth.yml --model SavedModel
```

Then you will see output like this:

```
Path: example/images/word_1.png, y_pred: [b'tiredness'], probability: [0.9998626]
Path: example/images/word_3.png, y_pred: [b'a'], probability: [0.67493004]
Path: example/images/2_Reimbursing_64165.jpg, y_pred: [b'reimbursing'], probability: [0.990946]
Path: example/images/word_2.png, y_pred: [b'kills'], probability: [0.9994573]
Path: example/images/1_Paintbrushes_55044.jpg, y_pred: [b'paintbrushes'], probability: [0.9984008]
Path: example/images/3_Creationisms_17934.jpg, y_pred: [b'creationisms'], probability: [0.99792457]
```

As for decoding methods, beam search sometimes performs better than greedy decoding, but it is more costly.

## Train

Before you start training, you should [prepare](#Data-prepare) the data first. All predictable characters are defined by the [table.txt](example/table.txt) file. The configuration of the training process is defined by the [yml](configs/mjsynth.yml) file. The training script uses all GPUs by default; if you want to use a specific GPU, set the `CUDA_VISIBLE_DEVICES` environment variable.

```bash
$ python crnn/train.py --config configs/mjsynth.yml --save_dir PATH/TO/SAVE
```

The training process can be visualized in TensorBoard.

```bash
$ tensorboard --logdir PATH/TO/MODEL_DIR
```

For more instructions, please refer to the [config](configs/mjsynth.yml) file.

## Data prepare

To train this network, you should prepare a lookup table, images, and corresponding labels. The example data is copied from the [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [ICDAR2013](https://rrc.cvc.uab.es/?ch=2&com=introduction) datasets.
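As a reference for the label conventions above, MJSynth-style file names carry their own label in the middle segment. This plain-Python sketch (a hypothetical helper, mirroring the split logic in `crnn/dataset_factory.py`) shows how the label is recovered:

```python
def mjsynth_label(path: str) -> str:
    """Extract the label from a name like 'images/1_Paintbrushes_55044.jpg'.

    MJSynth names follow the pattern index_LABEL_id.ext, so the label is
    the middle '_'-separated segment of the file name.
    """
    filename = path.rsplit("/", 1)[-1]  # drop the directory part
    return filename.split("_")[1]       # index_LABEL_id.ext -> LABEL

print(mjsynth_label("images/1_Paintbrushes_55044.jpg"))  # Paintbrushes
```

The real pipeline does the same splitting with `tf.strings.split` inside the `tf.data` graph, so no Python-side parsing is needed during training.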
### [Lookup table](./example/table.txt)

The file contains all characters plus the blank label. (The blank could in principle be at any index, but the TensorFlow decoders currently assume it is the last class, so put it last.) Any string can serve as the blank token.

### Image data

It's an end-to-end method, so we don't need to indicate the position of each character in the image.

![Paintbrushes](example/images/1_Paintbrushes_55044.jpg)
![Creationisms](example/images/3_Creationisms_17934.jpg)
![Reimbursing](example/images/2_Reimbursing_64165.jpg)

The labels corresponding to these three pictures are `Paintbrushes`, `Creationisms`, and `Reimbursing`.

### Annotation file

Write each image path and its corresponding label to a text file in one of the supported formats, as in the example data. The data input pipeline will detect the format automatically. Customization is also very simple; please check out the [dataset factory](crnn/dataset_factory.py).

#### Supported formats

- [MJSynth](./example/mjsynth_annotation.txt)
- [ICDAR2013/2015](./example/icdar2013_annotation.txt)
- [Simple](./example/simple_annotation.txt), such as [example.jpg label]

## Eval

```bash
$ python crnn/eval.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT
```

## Convert & Ecosystem

There are many components here to help you do other things, for example deploying with `TensorFlow Serving`. Before you deploy, pick a good weight and convert the model to the `SavedModel` format with the command below; it appends the post-processing layer at the end and strips the optimizer:

```bash
$ python crnn/export.py --config PATH/TO/CONFIG_FILE --weight PATH/TO/MODEL_WEIGHT --pre rescale --post greedy --output PATH/TO/OUTPUT
```

`TensorFlow Lite` can also convert this model, which means you can deploy it to Android, iOS, etc. Note: the decoders can't be converted to `TensorFlow Lite` because of their assets; use the softmax layer or `None`.
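The `greedy` post-processing exported above follows the standard CTC collapse rule: take the argmax class per time step, merge consecutive repeats, then drop the blank. A toy plain-Python sketch of that rule (illustration only; the exported model uses `tf.nn.ctc_greedy_decoder`):

```python
def ctc_greedy_collapse(argmax_ids, blank_id):
    """Collapse a per-timestep argmax sequence into a label sequence.

    Merges consecutive duplicates first, then removes the blank class,
    matching the merge_repeated=True behaviour of the TF greedy decoder.
    """
    out = []
    prev = None
    for t in argmax_ids:
        if t != prev and t != blank_id:
            out.append(t)
        prev = t
    return out

# With 3 classes ('a'=0, 'b'=1, blank=2): "aab-bb-b" collapses to "abb".
print(ctc_greedy_collapse([0, 0, 2, 1, 1, 2, 1], blank_id=2))  # [0, 1, 1]
```

This is also why the blank must sit at the last index of `table.txt`: the decoder identifies the blank by index, not by its text.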
================================================ FILE: crnn/dataset_factory.py ================================================ import os import re import tensorflow as tf try: AUTOTUNE = tf.data.AUTOTUNE except AttributeError: # tf < 2.4.0 AUTOTUNE = tf.data.experimental.AUTOTUNE class Dataset(tf.data.TextLineDataset): def __init__(self, filename, **kwargs): self.dirname = os.path.dirname(filename) super().__init__(filename, **kwargs) def parse_func(self, line): raise NotImplementedError def parse_line(self, line): line = tf.strings.strip(line) img_relative_path, label = self.parse_func(line) img_path = tf.strings.join([self.dirname, os.sep, img_relative_path]) return img_path, label class SimpleDataset(Dataset): def parse_func(self, line): splited_line = tf.strings.split(line) img_relative_path, label = splited_line[0], splited_line[1] return img_relative_path, label class MJSynthDataset(Dataset): def parse_func(self, line): splited_line = tf.strings.split(line) img_relative_path = splited_line[0] label = tf.strings.split(img_relative_path, sep="_")[1] return img_relative_path, label class ICDARDataset(Dataset): def parse_func(self, line): splited_line = tf.strings.split(line, sep=",") img_relative_path, label = splited_line[0], splited_line[1] label = tf.strings.strip(label) label = tf.strings.regex_replace(label, r'"', "") return img_relative_path, label class DatasetBuilder: def __init__( self, table_path, img_shape=(32, None, 3), max_img_width=300, ignore_case=False, ): # map unknown label to 0 self.table = tf.lookup.StaticHashTable( tf.lookup.TextFileInitializer( table_path, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, ), 0, ) self.img_shape = img_shape self.ignore_case = ignore_case if img_shape[1] is None: self.max_img_width = max_img_width self.preserve_aspect_ratio = True else: self.preserve_aspect_ratio = False @property def num_classes(self): return self.table.size() def _parse_annotation(self, path): 
with open(path) as f: line = f.readline().strip() if re.fullmatch(r".*/*\d+_.+_(\d+)\.\w+ \1", line): return MJSynthDataset(path) elif re.fullmatch(r'.*/*word_\d\.\w+, ".+"', line): return ICDARDataset(path) elif re.fullmatch(r".+\.\w+ .+", line): return SimpleDataset(path) else: raise ValueError("Unsupported annotation format") def _concatenate_ds(self, ann_paths): datasets = [self._parse_annotation(path) for path in ann_paths] concatenated_ds = datasets[0].map(datasets[0].parse_line) for ds in datasets[1:]: ds = ds.map(ds.parse_line) concatenated_ds = concatenated_ds.concatenate(ds) return concatenated_ds def _decode_img(self, filename, label): img = tf.io.read_file(filename) img = tf.io.decode_jpeg(img, channels=self.img_shape[-1]) if self.preserve_aspect_ratio: img_shape = tf.shape(img) scale_factor = self.img_shape[0] / img_shape[0] img_width = scale_factor * tf.cast(img_shape[1], tf.float64) img_width = tf.cast(img_width, tf.int32) else: img_width = self.img_shape[1] img = tf.image.resize(img, (self.img_shape[0], img_width)) / 255.0 return img, label def _filter_img(self, img, label): img_shape = tf.shape(img) return img_shape[1] < self.max_img_width def _tokenize(self, imgs, labels): chars = tf.strings.unicode_split(labels, "UTF-8") tokens = tf.ragged.map_flat_values(self.table.lookup, chars) # TODO(hym) Waiting for official support to use RaggedTensor in keras tokens = tokens.to_sparse() return imgs, tokens def __call__(self, ann_paths, batch_size, is_training): ds = self._concatenate_ds(ann_paths) if self.ignore_case: ds = ds.map(lambda x, y: (x, tf.strings.lower(y))) if is_training: ds = ds.shuffle(buffer_size=10000) ds = ds.map(self._decode_img, AUTOTUNE) if self.preserve_aspect_ratio and batch_size != 1: ds = ds.filter(self._filter_img) ds = ds.padded_batch(batch_size, drop_remainder=is_training) else: ds = ds.batch(batch_size, drop_remainder=is_training) ds = ds.map(self._tokenize, AUTOTUNE) ds = ds.prefetch(AUTOTUNE) return ds 
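The annotation-format auto-detection in `DatasetBuilder._parse_annotation` above hinges on three regexes applied to the first line of the file. This standalone sketch reuses those exact patterns (the `detect_format` helper name is hypothetical) so you can check which format a line will be classified as:

```python
import re

def detect_format(line: str) -> str:
    """Classify one annotation line, using the same patterns and the same
    precedence order as DatasetBuilder._parse_annotation."""
    if re.fullmatch(r".*/*\d+_.+_(\d+)\.\w+ \1", line):
        return "mjsynth"   # e.g. images/1_Paintbrushes_55044.jpg 55044
    if re.fullmatch(r'.*/*word_\d\.\w+, ".+"', line):
        return "icdar"     # e.g. images/word_1.png, "Tiredness"
    if re.fullmatch(r".+\.\w+ .+", line):
        return "simple"    # e.g. images/example.jpg label
    raise ValueError("Unsupported annotation format")

print(detect_format("images/1_Paintbrushes_55044.jpg 55044"))  # mjsynth
print(detect_format('images/word_1.png, "Tiredness"'))         # icdar
print(detect_format("images/word_1.png Tiredness"))            # simple
```

Note the order matters: a MJSynth line also matches the `simple` pattern, so the more specific patterns are tried first.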
================================================ FILE: crnn/decoders.py ================================================ import tensorflow as tf from tensorflow import keras class CTCDecoder(keras.layers.Layer): def __init__(self, table_path, **kwargs): super().__init__(**kwargs) self.table = tf.lookup.StaticHashTable( tf.lookup.TextFileInitializer( table_path, tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, ), "", ) def detokenize(self, x): x = tf.RaggedTensor.from_sparse(x) x = tf.ragged.map_flat_values(self.table.lookup, x) strings = tf.strings.reduce_join(x, axis=1) return strings class CTCGreedyDecoder(CTCDecoder): def __init__(self, table_path, merge_repeated=True, **kwargs): super().__init__(table_path, **kwargs) self.merge_repeated = merge_repeated def call(self, inputs): input_shape = tf.shape(inputs) sequence_length = tf.fill([input_shape[0]], input_shape[1]) decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder( tf.transpose(inputs, perm=[1, 0, 2]), sequence_length, self.merge_repeated, ) strings = self.detokenize(decoded[0]) labels = tf.cast(decoded[0], tf.int32) loss = tf.nn.ctc_loss( labels=labels, logits=inputs, label_length=None, logit_length=sequence_length, logits_time_major=False, blank_index=-1, ) probability = tf.math.exp(-loss) return strings, probability class CTCBeamSearchDecoder(CTCDecoder): def __init__(self, table_path, beam_width=100, top_paths=1, **kwargs): super().__init__(table_path, **kwargs) self.beam_width = beam_width self.top_paths = top_paths def call(self, inputs): input_shape = tf.shape(inputs) decoded, log_probability = tf.nn.ctc_beam_search_decoder( tf.transpose(inputs, perm=[1, 0, 2]), tf.fill([input_shape[0]], input_shape[1]), self.beam_width, self.top_paths, ) strings = [] for i in range(self.top_paths): strings.append(self.detokenize(decoded[i])) strings = tf.concat(strings, 1) probability = tf.math.exp(log_probability) return strings, probability 
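`CTCGreedyDecoder` above reports `probability = exp(-ctc_loss)`, i.e. the probability of the decoded label summed over all of its alignments. A cruder single-path approximation, shown here only to build intuition (pure Python, hypothetical helper, not what the layer computes), multiplies the per-step probability of the argmax path:

```python
def greedy_path_probability(step_probs):
    """Probability of the single argmax path.

    step_probs: list of per-timestep class-probability distributions
    (each a list summing to 1). This lower-bounds the exp(-ctc_loss)
    value, which additionally sums over all alignments of the label.
    """
    p = 1.0
    for dist in step_probs:
        p *= max(dist)  # take the most likely class at each step
    return p

probs = [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]]
print(greedy_path_probability(probs))  # 0.9 * 0.8 * 0.7 ≈ 0.504
```

Because the summed-alignment probability is at least as large as any single path's, `exp(-ctc_loss)` is the more faithful confidence score, which is why the decoder uses it.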
================================================ FILE: crnn/eval.py ================================================ import argparse import pprint import yaml from dataset_factory import DatasetBuilder from losses import CTCLoss from metrics import EditDistance from metrics import SequenceAccuracy from models import build_model parser = argparse.ArgumentParser() parser.add_argument( "--config", type=str, required=True, help="The config file path." ) parser.add_argument( "--weight", type=str, required=True, help="The saved weight path." ) args = parser.parse_args() with open(args.config) as f: config = yaml.load(f, Loader=yaml.Loader)["eval"] pprint.pprint(config) dataset_builder = DatasetBuilder(**config["dataset_builder"]) ds = dataset_builder(config["ann_paths"], config["batch_size"], False) model = build_model( dataset_builder.num_classes, weight=args.weight, img_shape=config["dataset_builder"]["img_shape"], ) model.compile(loss=CTCLoss(), metrics=[SequenceAccuracy(), EditDistance()]) model.evaluate(ds) ================================================ FILE: crnn/export.py ================================================ import argparse from pathlib import Path import yaml from tensorflow import keras from decoders import CTCBeamSearchDecoder from decoders import CTCGreedyDecoder from models import build_model parser = argparse.ArgumentParser() parser.add_argument( "--config", type=Path, required=True, help="The config file path." ) parser.add_argument( "--weight", type=str, required=True, default="", help="The saved weight path.", ) parser.add_argument("--pre", type=str, help="pre processing.") parser.add_argument("--post", type=str, help="Post processing.") parser.add_argument( "--output", type=str, required=True, help="The output path." 
) args = parser.parse_args() with args.config.open() as f: config = yaml.load(f, Loader=yaml.Loader)["dataset_builder"] with open(config["table_path"]) as f: num_classes = len(f.readlines()) if args.pre == "rescale": preprocess = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255) else: preprocess = None if args.post == "softmax": postprocess = keras.layers.Softmax() elif args.post == "greedy": postprocess = CTCGreedyDecoder(config["table_path"]) elif args.post == "beam_search": postprocess = CTCBeamSearchDecoder(config["table_path"]) else: postprocess = None model = build_model( num_classes, weight=args.weight, preprocess=preprocess, postprocess=postprocess, img_shape=config["img_shape"], ) model.summary() model.save(args.output) ================================================ FILE: crnn/losses.py ================================================ import tensorflow as tf from tensorflow import keras class CTCLoss(keras.losses.Loss): """A class that wraps the function of tf.nn.ctc_loss. Attributes: logits_time_major: If False (default) , shape is [batch, time, logits], If True, logits is shaped [time, batch, logits]. blank_index: Set the class index to use for the blank label. default is -1 (num_classes - 1). """ def __init__( self, logits_time_major=False, blank_index=-1, name="ctc_loss" ): super().__init__(name=name) self.logits_time_major = logits_time_major self.blank_index = blank_index def call(self, y_true, y_pred): """Computes CTC (Connectionist Temporal Classification) loss. work on CPU, because y_true is a SparseTensor. 
""" y_true = tf.cast(y_true, tf.int32) y_pred_shape = tf.shape(y_pred) logit_length = tf.fill([y_pred_shape[0]], y_pred_shape[1]) loss = tf.nn.ctc_loss( labels=y_true, logits=y_pred, label_length=None, logit_length=logit_length, logits_time_major=self.logits_time_major, blank_index=self.blank_index, ) return tf.math.reduce_mean(loss) ================================================ FILE: crnn/metrics.py ================================================ import tensorflow as tf from tensorflow import keras class SequenceAccuracy(keras.metrics.Metric): def __init__(self, name="sequence_accuracy", **kwargs): super().__init__(name=name, **kwargs) self.total = self.add_weight(name="total", initializer="zeros") self.count = self.add_weight(name="count", initializer="zeros") def update_state(self, y_true, y_pred, sample_weight=None): def sparse2dense(tensor, shape): tensor = tf.sparse.reset_shape(tensor, shape) tensor = tf.sparse.to_dense(tensor, default_value=-1) tensor = tf.cast(tensor, tf.float32) return tensor y_true_shape = tf.shape(y_true) batch_size = y_true_shape[0] y_pred_shape = tf.shape(y_pred) max_width = tf.math.maximum(y_true_shape[1], y_pred_shape[1]) logit_length = tf.fill([batch_size], y_pred_shape[1]) decoded, _ = tf.nn.ctc_greedy_decoder( inputs=tf.transpose(y_pred, perm=[1, 0, 2]), sequence_length=logit_length, ) y_true = sparse2dense(y_true, [batch_size, max_width]) y_pred = sparse2dense(decoded[0], [batch_size, max_width]) num_errors = tf.math.reduce_any( tf.math.not_equal(y_true, y_pred), axis=1 ) num_errors = tf.cast(num_errors, tf.float32) num_errors = tf.math.reduce_sum(num_errors) batch_size = tf.cast(batch_size, tf.float32) self.total.assign_add(batch_size) self.count.assign_add(batch_size - num_errors) def result(self): return self.count / self.total def reset_states(self): self.count.assign(0) self.total.assign(0) class EditDistance(keras.metrics.Metric): def __init__(self, name="edit_distance", **kwargs): super().__init__(name=name, **kwargs) 
self.total = self.add_weight(name="total", initializer="zeros") self.sum_distance = self.add_weight( name="sum_distance", initializer="zeros" ) def update_state(self, y_true, y_pred, sample_weight=None): y_pred_shape = tf.shape(y_pred) batch_size = y_pred_shape[0] logit_length = tf.fill([batch_size], y_pred_shape[1]) decoded, _ = tf.nn.ctc_greedy_decoder( inputs=tf.transpose(y_pred, perm=[1, 0, 2]), sequence_length=logit_length, ) sum_distance = tf.math.reduce_sum(tf.edit_distance(decoded[0], y_true)) batch_size = tf.cast(batch_size, tf.float32) self.sum_distance.assign_add(sum_distance) self.total.assign_add(batch_size) def result(self): return self.sum_distance / self.total def reset_states(self): self.sum_distance.assign(0) self.total.assign(0) ================================================ FILE: crnn/models.py ================================================ from tensorflow import keras from tensorflow.keras import layers def vgg_style(x): """ The original feature extraction structure from CRNN paper. 
Related paper: https://ieeexplore.ieee.org/abstract/document/7801919 """ x = layers.Conv2D(64, 3, padding="same", activation="relu", name="conv1")(x) x = layers.MaxPool2D(pool_size=2, padding="same", name="pool1")(x) x = layers.Conv2D(128, 3, padding="same", activation="relu", name="conv2")( x ) x = layers.MaxPool2D(pool_size=2, padding="same", name="pool2")(x) x = layers.Conv2D(256, 3, padding="same", use_bias=False, name="conv3")(x) x = layers.BatchNormalization(name="bn3")(x) x = layers.Activation("relu", name="relu3")(x) x = layers.Conv2D(256, 3, padding="same", activation="relu", name="conv4")( x ) x = layers.MaxPool2D( pool_size=2, strides=(2, 1), padding="same", name="pool4" )(x) x = layers.Conv2D(512, 3, padding="same", use_bias=False, name="conv5")(x) x = layers.BatchNormalization(name="bn5")(x) x = layers.Activation("relu", name="relu5")(x) x = layers.Conv2D(512, 3, padding="same", activation="relu", name="conv6")( x ) x = layers.MaxPool2D( pool_size=2, strides=(2, 1), padding="same", name="pool6" )(x) x = layers.Conv2D(512, 2, use_bias=False, name="conv7")(x) x = layers.BatchNormalization(name="bn7")(x) x = layers.Activation("relu", name="relu7")(x) x = layers.Reshape((-1, 512), name="reshape7")(x) return x def build_model( num_classes, weight=None, preprocess=None, postprocess=None, img_shape=(32, None, 3), model_name="crnn", ): x = img_input = keras.Input(shape=img_shape) if preprocess is not None: x = preprocess(x) x = vgg_style(x) x = layers.Bidirectional( layers.LSTM(units=256, return_sequences=True), name="bi_lstm1" )(x) x = layers.Bidirectional( layers.LSTM(units=256, return_sequences=True), name="bi_lstm2" )(x) x = layers.Dense(units=num_classes, name="logits")(x) if postprocess is not None: x = postprocess(x) model = keras.Model(inputs=img_input, outputs=x, name=model_name) if weight is not None: model.load_weights(weight, by_name=True, skip_mismatch=True) return model ================================================ FILE: crnn/train.py 
================================================ import argparse import pprint import shutil from pathlib import Path import tensorflow as tf import yaml from tensorflow import keras from dataset_factory import DatasetBuilder from losses import CTCLoss from metrics import SequenceAccuracy from models import build_model parser = argparse.ArgumentParser() parser.add_argument( "--config", type=Path, required=True, help="The config file path." ) parser.add_argument( "--save_dir", type=Path, required=True, help="The path to save the models, logs, etc.", ) args = parser.parse_args() with args.config.open() as f: config = yaml.load(f, Loader=yaml.Loader)["train"] pprint.pprint(config) args.save_dir.mkdir(exist_ok=True) if list(args.save_dir.iterdir()): raise ValueError(f"{args.save_dir} is not a empty folder") shutil.copy(args.config, args.save_dir / args.config.name) strategy = tf.distribute.MirroredStrategy() batch_size = config["batch_size_per_replica"] * strategy.num_replicas_in_sync dataset_builder = DatasetBuilder(**config["dataset_builder"]) train_ds = dataset_builder(config["train_ann_paths"], batch_size, True) val_ds = dataset_builder(config["val_ann_paths"], batch_size, False) with strategy.scope(): lr_schedule = keras.optimizers.schedules.CosineDecay( **config["lr_schedule"] ) model = build_model( dataset_builder.num_classes, weight=config.get("weight"), img_shape=config["dataset_builder"]["img_shape"], ) model.compile( optimizer=keras.optimizers.Adam(lr_schedule), loss=CTCLoss(), metrics=[SequenceAccuracy()], ) model.summary() model_prefix = "{epoch}_{val_loss:.4f}_{val_sequence_accuracy:.4f}" model_path = f"{args.save_dir}/{model_prefix}.h5" callbacks = [ keras.callbacks.ModelCheckpoint(model_path, save_weights_only=True), keras.callbacks.TensorBoard( log_dir=f"{args.save_dir}/logs", **config["tensorboard"] ), ] model.fit( train_ds, epochs=config["epochs"], callbacks=callbacks, validation_data=val_ds, ) ================================================ FILE: 
example/icdar2013_annotation.txt ================================================ images/word_1.png, "Tiredness" images/word_2.png, "kills" images/word_3.png, "A" ================================================ FILE: example/mjsynth_annotation.txt ================================================ images/1_Paintbrushes_55044.jpg 55044 images/2_Reimbursing_64165.jpg 64165 images/3_Creationisms_17934.jpg 17934 ================================================ FILE: example/simple_annotation.txt ================================================ images/1_Paintbrushes_55044.jpg Paintbrushes images/2_Reimbursing_64165.jpg Reimbursing images/word_1.png Tiredness images/word_2.png kills ================================================ FILE: example/table.txt ================================================ 0 1 2 3 4 5 6 7 8 9 a b c d e f g h i j k l m n o p q r s t u v w x y z ================================================ FILE: pyproject.toml ================================================ [tool.black] line-length = 80 target-version = ['py37'] include = '\.pyi?$' # 'extend-exclude' excludes files or directories in addition to the defaults extend-exclude = ''' # A regex preceded with ^/ will apply only to files and directories # in the root of the project. 
''' [tool.isort] profile = "black" src_paths = ["crnn"] line_length = 80 force_single_line = true ================================================ FILE: requirements.txt ================================================ pyyaml tensorflow >= 2.3 ================================================ FILE: requirements_dev.txt ================================================ -r requirements.txt pre-commit ================================================ FILE: tools/demo.py ================================================ import argparse from pathlib import Path import tensorflow as tf import yaml from tensorflow import keras parser = argparse.ArgumentParser() parser.add_argument( "--images", type=str, required=True, help="Image file or folder path." ) parser.add_argument( "--config", type=Path, required=True, help="The config file path." ) parser.add_argument("--model", type=str, required=True, help="The saved model.") args = parser.parse_args() with args.config.open() as f: config = yaml.load(f, Loader=yaml.Loader)["dataset_builder"] def read_img_and_resize(path, shape): img = tf.io.read_file(path) img = tf.io.decode_jpeg(img, channels=shape[2]) if shape[1] is None: img_shape = tf.shape(img) scale_factor = shape[0] / img_shape[0] img_width = scale_factor * tf.cast(img_shape[1], tf.float64) img_width = tf.cast(img_width, tf.int32) else: img_width = shape[1] img = tf.image.resize(img, (shape[0], img_width)) return img model = keras.models.load_model(args.model, compile=False) p = Path(args.images) img_paths = p.iterdir() if p.is_dir() else [p] for img_path in img_paths: img = read_img_and_resize(str(img_path), config["img_shape"]) img = tf.expand_dims(img, 0) outputs = model(img) print( f"Path: {img_path}, y_pred: {outputs[0].numpy()}, " f"probability: {outputs[1].numpy()}" )