.
etree.SubElement(attributes, 'divisions').text = '8'
# Key signature for the first staff.
key = etree.SubElement(attributes, 'key')
key.attrib['number'] = '1'
# Dummy value
key.text = 'C major'
# Key signature for the second staff.
key = etree.SubElement(attributes, 'key')
key.attrib['number'] = '2'
# Dummy value
key.text = 'G major'
part_staves = musicxml.PartStaves(score)
self.assertEqual(part_staves.num_partstaves(), 2)
staff_1_measure = part_staves.get_measure(0, 0)
self.assertEqual(len(staff_1_measure), 1)
attributes = staff_1_measure.find('attributes')
self.assertEqual(len(attributes), 2)
self.assertEqual(etree.tostring(attributes[0]), b'8 ')
self.assertEqual(etree.tostring(attributes[1]), b'C major ')
staff_2_measure = part_staves.get_measure(1, 0)
self.assertEqual(len(staff_2_measure), 1)
attributes = staff_2_measure.find('attributes')
self.assertEqual(len(attributes), 2)
self.assertEqual(etree.tostring(attributes[0]), b'8 ')
self.assertEqual(etree.tostring(attributes[1]), b'G major ')
if __name__ == '__main__':
absltest.main()
================================================
FILE: moonlight/glyphs/BUILD
================================================
# Description:
# Glyph classification for OMR.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
# Aggregate target with no sources of its own: depending on ":glyphs" pulls
# in every glyph-classification library in this package.
py_library(
    name = "glyphs",
    deps = [
        ":base",
        ":corpus",
        ":geometry",
        ":glyph_types",
        ":knn",
        ":knn_model",
        ":neural",
        ":note_dots",
        ":repeated",
        ":saved_classifier",
    ],
)
# Reads labeled glyph patch Examples from TFRecords (corpus.py).
py_library(
    name = "corpus",
    srcs = ["corpus.py"],
    srcs_version = "PY2AND3",
    deps = [
        # numpy dep
        # tensorflow dep
    ],
)

# Base glyph classifier model and the glyphs tensor column layout (base.py).
py_library(
    name = "base",
    srcs = ["base.py"],
    srcs_version = "PY2AND3",
    deps = [
        # enum34 dep
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
    ],
)
# 1D convolutional glyph classifier base: extracts discrete glyphs from dense
# per-pixel staffline predictions (convolutional.py).
py_library(
    name = "convolutional",
    srcs = ["convolutional.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base",
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
        # tensorflow dep
    ],
)

py_test(
    name = "convolutional_test",
    srcs = ["convolutional_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base",
        ":convolutional",
        ":testing",
        # disable_tf2
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
        # pandas dep
        # tensorflow dep
    ],
)
# K-nearest-neighbors glyph classifier built on the convolutional base
# (knn.py).
py_library(
    name = "knn",
    srcs = ["knn.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convolutional",
        ":corpus",
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/util:patches",
        # tensorflow dep
    ],
)

# Standalone KNN k-means classifier graph and saved-model export
# (knn_model.py).
py_library(
    name = "knn_model",
    srcs = ["knn_model.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
        # tensorflow dep
    ],
)

py_test(
    name = "knn_test",
    srcs = ["knn_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base",
        ":knn",
        # disable_tf2
        "//moonlight/protobuf:protobuf_py_pb2",
        # pandas dep
        # tensorflow dep
    ],
)
# Glyph y-coordinate calculation (geometry.py); pure Python, no deps.
py_library(
    name = "geometry",
    srcs = ["geometry.py"],
    srcs_version = "PY2AND3",
)

# Predicates over Glyph.Type: noteheads, clefs, rests, ... (glyph_types.py).
py_library(
    name = "glyph_types",
    srcs = ["glyph_types.py"],
    srcs_version = "PY2AND3",
    deps = ["//moonlight/protobuf:protobuf_py_pb2"],
)

py_library(
    name = "neural",
    srcs = ["neural.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/protobuf:protobuf_py_pb2",
        # tensorflow dep
    ],
)
# Test for the neural classifier library above.
py_test(
    name = "neural_test",
    size = "small",
    srcs = ["neural_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":neural",
        # disable_tf2
        # numpy dep
        # tensorflow dep
    ],
)
# Runs a saved classifier model over extracted stafflines
# (saved_classifier.py).
py_library(
    name = "saved_classifier",
    srcs = ["saved_classifier.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convolutional",
        "//moonlight/staves:staffline_extractor",
        "//moonlight/util:patches",
        # tensorflow dep
        # tensorflow.contrib.graph_editor py dep
        # tensorflow.contrib.util py dep
    ],
)

py_test(
    name = "saved_classifier_test",
    srcs = ["saved_classifier_test.py"],
    data = ["//moonlight/testdata:images"],
    srcs_version = "PY2AND3",
    deps = [
        ":saved_classifier",
        # disable_tf2
        "//moonlight:image",
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/structure",
        # tensorflow dep
    ],
)

# Bundles the saved classifier with the checked-in default model data.
py_library(
    name = "saved_classifier_fn",
    srcs = ["saved_classifier_fn.py"],
    data = ["//moonlight/data/glyphs_nn_model_20180808"],
    srcs_version = "PY2AND3",
    deps = [":saved_classifier"],
)
# Note dot handling (note_dots.py).
py_library(
    name = "note_dots",
    srcs = ["note_dots.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":geometry",
        ":glyph_types",
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/structure:components",
        # numpy dep
    ],
)
# Repeated glyph handling (repeated.py).
py_library(
    name = "repeated",
    srcs = ["repeated.py"],
    srcs_version = "PY2AND3",
    # ":glyph_types" is the same target as "//moonlight/glyphs:glyph_types";
    # use the relative label for consistency with the other rules in this file.
    deps = [":glyph_types"],
)
# Test-only helpers shared by the classifier tests (testing.py).
py_library(
    name = "testing",
    testonly = True,
    srcs = ["testing.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":convolutional",
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
        # tensorflow dep
    ],
)
================================================
FILE: moonlight/glyphs/base.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base glyph classifier model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import enum
import numpy as np
from moonlight.protobuf import musicscore_pb2
class GlyphsTensorColumns(enum.IntEnum):
  """The columns of the glyphs tensors.

  Glyphs should be held in a 2D tensor where the columns are the staff of the
  glyph, the vertical position on the staff, x coordinate, and glyph type.
  """
  STAFF_INDEX = 0  # Index of the staff the glyph belongs to.
  Y_POSITION = 1  # Vertical position; 0 is the center staff line.
  X = 2  # Horizontal (x) coordinate.
  TYPE = 3  # Glyph type (a musicscore_pb2.Glyph type value).
class BaseGlyphClassifier(object):
  """The base glyph classifier model."""
  # Python 2 style ABC declaration (inert under Python 3, but subclasses are
  # still expected to implement the abstract methods).
  __metaclass__ = abc.ABCMeta

  def __init__(self):
    """Base constructor for a glyph classifier.

    Attributes:
      staffline_extractor: Optional staffline extractor, if used for
        classification. If present, classification uses the scaled stafflines,
        and glyph x positions will be scaled back to page coordinates when
        constructing the Page. If None, no scaling is done.
    """
    self.staffline_extractor = None

  @abc.abstractmethod
  def get_detected_glyphs(self):
    """Detects glyphs in the image.

    Each glyph belongs to a staff, and has a y position numbered from 0 for the
    center staff line.

    Returns:
      A Tensor of glyphs, with shape (num_glyphs, 4). The columns are indexed by
      `GlyphsTensorColumns`. The glyphs will be sorted later, so they may be
      in any order.
    """
    pass

  def glyph_predictions_to_page(self, predictions):
    """Converts the glyph predictions to a Page message.

    Args:
      predictions: NumPy array which is equal to
        `self.get_detected_glyphs().eval()` (but multiple tensors are evaluated
        in a single run for efficiency.) Shape `(num_glyphs, 3)`.

    Returns:
      A `Page` message holding a single `StaffSystem`, with `Staff` messages
      that only hold `Glyph`s. Structural information is added to the page by
      `OMREngine`.
    """
    staff_col = int(GlyphsTensorColumns.STAFF_INDEX)
    if predictions.size:
      num_staves = predictions[:, staff_col].max() + 1
    else:
      num_staves = 0

    def build_glyph(row):
      # One prediction row -> one Glyph message.
      return musicscore_pb2.Glyph(
          x=row[GlyphsTensorColumns.X],
          y_position=row[GlyphsTensorColumns.Y_POSITION],
          type=row[GlyphsTensorColumns.TYPE])

    def build_staff(staff_num):
      staff_rows = predictions[predictions[:, staff_col] == staff_num]
      # For determinism, sort glyphs by x, breaking ties by y position (low to
      # high). lexsort treats the last key (x) as primary.
      order = np.lexsort(
          staff_rows[:, [GlyphsTensorColumns.Y_POSITION,
                         GlyphsTensorColumns.X]].T)
      staff_rows = staff_rows[order]
      return musicscore_pb2.Staff(
          glyph=[build_glyph(row) for row in staff_rows])

    return musicscore_pb2.Page(system=[
        musicscore_pb2.StaffSystem(
            staff=[build_staff(i) for i in range(num_staves)])
    ])
================================================
FILE: moonlight/glyphs/convolutional.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base glyph classifier model."""
# TODO(ringw): Replace subclasses with a saved TF model. Hardcode the
# stafflines and predictions tensor names, so that we define the classifier
# separately. It can either be defined in the same graph before constructing the
# Convolutional1DGlyphClassifier, or loaded from a saved model after training
# externally.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import numpy as np
import tensorflow as tf
from moonlight.glyphs import base
from moonlight.protobuf import musicscore_pb2
DEFAULT_RUN_MIN_LENGTH = 3
class Convolutional1DGlyphClassifier(base.BaseGlyphClassifier):
  """The base 1D convolutional glyph classifier model.

  Wraps a dense per-pixel prediction (`staffline_predictions`) and extracts
  discrete glyphs by detecting horizontal runs of the same predicted type.
  """

  def __init__(self, run_min_length=DEFAULT_RUN_MIN_LENGTH):
    """Base classifier model.

    Args:
      run_min_length: Must have this many consecutive pixels with the same
        non-NONE predicted glyph to emit the glyph.
    """
    super(Convolutional1DGlyphClassifier, self).__init__()
    self.run_min_length = run_min_length

  @property
  @abc.abstractmethod
  def staffline_predictions(self):
    """The staffline predictions tensor.

    Convolutional1DGlyphClassifier wraps this output, which would be the output
    of a 1D convolutional model, and extracts individual glyphs to be added to
    the Page message.

    Shape (num_staves, num_stafflines, width).
    """
    pass

  def _build_detected_glyphs(self, predictions_arr):
    """Takes the convolutional output ndarray and builds the individual glyphs.

    At each staff and y position, looks for short runs of the same detected
    glyph, and then outputs a single glyph at the x coordinate of the center of
    the run.

    Args:
      predictions_arr: A NumPy array with the result of `staffline_predictions`.
        Shape (num_staves, num_stafflines, width).

    Returns:
      A 2D array of the glyph coordinates. Shape (num_glyphs, 4) with columns
      corresponding to base.GlyphsTensorColumns.
    """
    glyphs = []
    num_staves, num_stafflines, width = predictions_arr.shape
    for staff in range(num_staves):
      for staffline in range(num_stafflines):
        # Stafflines are indexed from the top; y positions count up from 0 at
        # the center line.
        y_position = num_stafflines // 2 - staffline
        run_start = -1
        run_value = musicscore_pb2.Glyph.NONE
        # Pre-initialize `value` so that the flush iteration (x == width) does
        # not read an unbound local when width == 0 (previously a NameError).
        value = musicscore_pb2.Glyph.NONE
        # Iterate one column past the end so the final run is flushed.
        for x in range(width + 1):
          if x < width:
            value = predictions_arr[staff, staffline, x]
          if x == width or value != run_value:
            if run_value > musicscore_pb2.Glyph.NONE:
              # Process the current run if it is at least run_min_length pixels.
              if x - run_start >= self.run_min_length:
                # Emit a single glyph at the center of the run.
                glyph_center_x = (run_start + x) // 2
                glyphs.append(
                    self._create_glyph_arr(staff, y_position, glyph_center_x,
                                           run_value))
            run_value = value
            run_start = x
    # Convert to a 2D array. The reshape yields shape (0, 4) when no glyphs
    # were detected.
    glyphs = np.asarray(glyphs, np.int32)
    return np.reshape(glyphs, (-1, 4))

  def _create_glyph_arr(self, staff_index, y_position, x, type_value):
    """Packs a single glyph into a row vector indexed by GlyphsTensorColumns."""
    glyph = np.empty(len(base.GlyphsTensorColumns), np.int32)
    glyph[base.GlyphsTensorColumns.STAFF_INDEX] = staff_index
    glyph[base.GlyphsTensorColumns.Y_POSITION] = y_position
    glyph[base.GlyphsTensorColumns.X] = x
    glyph[base.GlyphsTensorColumns.TYPE] = type_value
    return glyph

  def get_detected_glyphs(self):
    """Extracts the individual glyphs as a Tensor.

    This is run in the TensorFlow graph, so we have to wrap the Python glyph
    logic in a `py_func`.

    Returns:
      A Tensor of glyphs, with shape (num_glyphs, 4). The columns are indexed by
      `base.GlyphsTensorColumns`.
    """
    return tf.py_func(self._build_detected_glyphs, [self.staffline_predictions],
                      tf.int32)
================================================
FILE: moonlight/glyphs/convolutional_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Convolutional1DGlyphClassifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import tensorflow as tf
from moonlight.glyphs import base
from moonlight.glyphs import convolutional
from moonlight.glyphs import testing
from moonlight.protobuf import musicscore_pb2
STAFF_INDEX = base.GlyphsTensorColumns.STAFF_INDEX
Y_POSITION = base.GlyphsTensorColumns.Y_POSITION
X = base.GlyphsTensorColumns.X
TYPE = base.GlyphsTensorColumns.TYPE
class ConvolutionalTest(tf.test.TestCase):

  def testGetGlyphsPage(self):
    # Refer to testing.py for the glyphs array.
    # Expected glyphs: one row per glyph run in testing.PREDICTIONS, with
    # run_min_length=1 so every non-NONE pixel run is emitted.
    # pyformat: disable
    glyphs = pd.DataFrame(
        [
            {STAFF_INDEX: 0, Y_POSITION: 0, X: 0, TYPE: 3},
            {STAFF_INDEX: 0, Y_POSITION: -1, X: 1, TYPE: 4},
            {STAFF_INDEX: 0, Y_POSITION: 0, X: 2, TYPE: 5},
            {STAFF_INDEX: 0, Y_POSITION: 1, X: 4, TYPE: 2},
            {STAFF_INDEX: 1, Y_POSITION: 1, X: 2, TYPE: 3},
            {STAFF_INDEX: 1, Y_POSITION: 0, X: 2, TYPE: 5},
            {STAFF_INDEX: 1, Y_POSITION: -1, X: 4, TYPE: 3},
            {STAFF_INDEX: 1, Y_POSITION: -1, X: 5, TYPE: 5},
        ],
        columns=[STAFF_INDEX, Y_POSITION, X, TYPE])
    # Compare glyphs (rows in the glyphs array) regardless of their position in
    # the array (they are not required to be sorted).
    self.assertEqual(
        set(
            map(tuple,
                convolutional.Convolutional1DGlyphClassifier(
                    run_min_length=1)._build_detected_glyphs(
                        testing.PREDICTIONS))),
        set(map(tuple, glyphs.values)))

  def testNoGlyphs_dummyClassifier(self):

    class DummyClassifier(convolutional.Convolutional1DGlyphClassifier):
      """Outputs the classifications for no glyphs on multiple staves."""

      @property
      def staffline_predictions(self):
        # All-NONE predictions: 5 staves, 9 stafflines, width 100.
        return tf.fill([5, 9, 100], musicscore_pb2.Glyph.NONE)

    with self.test_session():
      # With no non-NONE runs anywhere, the output is an empty (0, 4) array.
      self.assertAllEqual(
          DummyClassifier().get_detected_glyphs().eval(),
          np.zeros((0, 4), np.int32))
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/glyphs/corpus.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Labeled glyph corpus.
Reads Examples holding image patches and glyph records from TFRecords.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def get_patch_shape(corpus_file):
  """Gets the patch shape (height, width) from the corpus file.

  Args:
    corpus_file: Path to a TFRecords file.

  Returns:
    A tuple (height, width), extracted from the first record.

  Raises:
    ValueError: if the corpus_file is empty.
  """
  record_iterator = tf.python_io.tf_record_iterator(corpus_file)
  try:
    first_record = next(record_iterator)
  except StopIteration as e:
    raise ValueError('corpus_file cannot be empty: %s' % e)
  # Only the first record is inspected; all patches share one shape.
  example = tf.train.Example()
  example.ParseFromString(first_record)
  features = example.features.feature
  return (features['height'].int64_list.value[0],
          features['width'].int64_list.value[0])
def parse_corpus(corpus_file, height, width):
  """Returns NumPy arrays holding the parsed result of the corpus file.

  Uses the default TensorFlow session to read examples.

  Args:
    corpus_file: Path to a TFRecords file.
    height: Patch height, as returned from `get_patch_shape`.
    width: Patch width, as returned from `get_patch_shape`.

  Returns:
    patches: float32 NumPy array with shape (num_patches, height, width).
    labels: int64 NumPy array with shape (num_patches,).
  """
  sess = tf.get_default_session()
  # Read the file exactly once, in batches of up to 10000 serialized examples.
  producer = tf.train.string_input_producer([corpus_file], num_epochs=1)
  unused_keys, examples = tf.TFRecordReader().read_up_to(producer, 10000)
  parsed_examples = tf.parse_example(
      examples, {
          'patch': tf.FixedLenFeature((height, width), tf.float32),
          'label': tf.FixedLenFeature((), tf.int64)
      })
  sess.run(tf.local_variables_initializer())  # initialize num_epochs
  coord = tf.train.Coordinator()
  queue_runners = tf.train.start_queue_runners(start=True, coord=coord)
  assert queue_runners, 'started queue runners'
  all_patches = []
  all_labels = []
  # Keep evaluating batches until the single epoch is exhausted.
  while True:
    try:
      patch, label = sess.run(
          [parsed_examples['patch'], parsed_examples['label']])
    except tf.errors.OutOfRangeError:
      break  # done
    all_patches.append(patch)
    all_labels.append(label)
  coord.request_stop()
  coord.join()
  # NOTE(review): np.concatenate raises on an empty list; an empty corpus_file
  # is rejected earlier by get_patch_shape, so no batches here implies a bug.
  return np.concatenate(all_patches), np.concatenate(all_labels)
================================================
FILE: moonlight/glyphs/geometry.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Glyph y coordinate calculation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def glyph_y(staff, glyph):
  """Calculates the glyph y coordinate.

  Args:
    staff: A Staff, used for interpolating the staff center y coordinate.
    glyph: A Glyph on the Staff.

  Returns:
    The y coordinate of the glyph on the page.

  Raises:
    ValueError: If the glyph is not contained by the interval spanned by the
      staff on the x axis.
  """
  center = staff.center_line
  for segment_start, segment_end in zip(center[:-1], center[1:]):
    if not segment_start.x <= glyph.x < segment_end.x:
      continue
    # Linearly interpolate the staff center y at the glyph's x coordinate
    # (integer arithmetic, matching the page's pixel coordinates).
    delta_y = segment_end.y - segment_start.y
    staff_center_y = segment_start.y + (
        delta_y * (glyph.x - segment_start.x) //
        (segment_end.x - segment_start.x))
    # y positions count up (in the negative y direction).
    return staff_center_y - staff.staffline_distance * glyph.y_position // 2
  raise ValueError('Glyph (%s) is not contained by staff (%s)' %
                   (glyph, staff.center_line))
================================================
FILE: moonlight/glyphs/glyph_types.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility for glyph types.
Determines which modifiers may be attached to which glyphs (currently just
noteheads).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.protobuf import musicscore_pb2
def is_notehead(glyph):
  """Returns True if the glyph is any kind of notehead."""
  notehead_types = (
      musicscore_pb2.Glyph.NOTEHEAD_EMPTY,
      musicscore_pb2.Glyph.NOTEHEAD_FILLED,
      musicscore_pb2.Glyph.NOTEHEAD_WHOLE,
  )
  return glyph.type in notehead_types
def is_stemmed_notehead(glyph):
  """Returns True if the glyph is a notehead type that can take a stem."""
  return glyph.type in (musicscore_pb2.Glyph.NOTEHEAD_EMPTY,
                        musicscore_pb2.Glyph.NOTEHEAD_FILLED)
def is_beamed_notehead(glyph):
  """Returns True if the glyph is a notehead type that can be beamed."""
  return musicscore_pb2.Glyph.NOTEHEAD_FILLED == glyph.type
def is_dotted_notehead(glyph):
  """Returns True if the glyph may carry an augmentation dot.

  Any notehead can be dotted.
  """
  return is_notehead(glyph)
def is_clef(glyph):
  """Returns True if the glyph is a clef."""
  return glyph.type in (musicscore_pb2.Glyph.CLEF_TREBLE,
                        musicscore_pb2.Glyph.CLEF_BASS)
def is_rest(glyph):
  """Returns True if the glyph is a rest."""
  rest_types = (
      musicscore_pb2.Glyph.REST_QUARTER,
      musicscore_pb2.Glyph.REST_EIGHTH,
      musicscore_pb2.Glyph.REST_SIXTEENTH,
  )
  return glyph.type in rest_types
================================================
FILE: moonlight/glyphs/knn.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""K-Nearest-Neighbors glyph classification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.glyphs import convolutional
from moonlight.glyphs import corpus
from moonlight.protobuf import musicscore_pb2
from moonlight.util import patches
# k = 3 has the best performance for noteheads, clefs, and sharps. k = 5 seems
# to increase false negatives, so we probably don't want to increase k further
# with our current data.
K_NEAREST_VALUE = 3
NUM_GLYPHS = len(musicscore_pb2.Glyph.Type.keys())
class NearestNeighborGlyphClassifier(
    convolutional.Convolutional1DGlyphClassifier):
  """Classifies staffline patches using 1 nearest neighbor."""

  def __init__(self, corpus_file, staffline_extractor, **kwargs):
    """Build a 1-nearest-neighbor classifier with labeled patches.

    Args:
      corpus_file: Path to the TFRecords of Examples with patch (cluster) values
        in the "patch" feature, and the glyph label in the "label" feature.
      staffline_extractor: The staffline extractor.
      **kwargs: Passed through to `Convolutional1DGlyphClassifier`.
    """
    super(NearestNeighborGlyphClassifier, self).__init__(**kwargs)
    patch_height, patch_width = corpus.get_patch_shape(corpus_file)
    centroids, labels = corpus.parse_corpus(corpus_file, patch_height,
                                            patch_width)
    # Flatten each (patch_height, patch_width) centroid to a feature vector.
    centroids_shape = tf.shape(centroids)
    flattened_centroids = tf.reshape(
        centroids,
        [centroids_shape[0], centroids_shape[1] * centroids_shape[2]])
    self.staffline_extractor = staffline_extractor
    stafflines = staffline_extractor.extract_staves()
    # Collapse the stafflines per stave.
    width = tf.shape(stafflines)[-1]
    # Shape (num_staves, num_stafflines, num_patches, height, patch_width).
    staffline_patches = patches.patches_1d(stafflines, patch_width)
    staffline_patches_shape = tf.shape(staffline_patches)
    # Flatten the patches the same way as the centroids, so every patch can be
    # compared against every centroid in a single distance matrix.
    flattened_patches = tf.reshape(staffline_patches, [
        staffline_patches_shape[0] * staffline_patches_shape[1] *
        staffline_patches_shape[2],
        staffline_patches_shape[3] * staffline_patches_shape[4]
    ])
    distance_matrix = _squared_euclidean_distance_matrix(
        flattened_patches, flattened_centroids)
    # Take the k centroids with the lowest distance to each patch. Wrap the k
    # constant in a tf.identity, which tests can use to feed in another value.
    k_value = tf.identity(tf.constant(K_NEAREST_VALUE), name='k_nearest_value')
    nearest_centroid_inds = tf.nn.top_k(-distance_matrix, k=k_value)[1]
    # Get the label corresponding to each nearby centroids, and reshape the
    # labels back to the original shape.
    nearest_labels = tf.reshape(
        tf.gather(labels, tf.reshape(nearest_centroid_inds, [-1])),
        tf.shape(nearest_centroid_inds))
    # Make a histogram of counts for each glyph type in the nearest centroids,
    # for each row (patch).
    bins = tf.map_fn(lambda row: tf.bincount(row, minlength=NUM_GLYPHS),
                     tf.to_int32(nearest_labels))
    # Take the argmax of the histogram to get the top prediction. Discard glyph
    # type 1 (NONE) for now. The slice drops bins [0, NONE], so the argmax
    # index is offset by NONE + 1 = 2 to recover the glyph type.
    mode_out_of_k = tf.argmax(
        bins[:, musicscore_pb2.Glyph.NONE + 1:], axis=1) + 2
    # Force predictions to NONE only if all k nearby centroids were NONE.
    # Otherwise, the non-NONE nearby centroids will contribute to the
    # prediction.
    mode_out_of_k = tf.where(
        tf.equal(bins[:, musicscore_pb2.Glyph.NONE], k_value),
        tf.fill(
            tf.shape(mode_out_of_k), tf.to_int64(musicscore_pb2.Glyph.NONE)),
        mode_out_of_k)
    predictions = tf.reshape(mode_out_of_k, staffline_patches_shape[:3])
    # Pad the output. Patch extraction shrinks the width, so pad the
    # predictions back to the full staffline width, centered, with NONE on
    # either side.
    predictions_width = tf.shape(predictions)[-1]
    pad_before = (width - predictions_width) // 2
    pad_shape_before = tf.concat([staffline_patches_shape[:2], [pad_before]],
                                 axis=0)
    pad_shape_after = tf.concat(
        [staffline_patches_shape[:2], [width - predictions_width - pad_before]],
        axis=0)
    self.output = tf.concat(
        [
            # NONE has value 1.
            tf.ones(pad_shape_before, tf.int64),
            predictions,
            tf.ones(pad_shape_after, tf.int64),
        ],
        axis=-1)

  @property
  def staffline_predictions(self):
    """The padded predictions. Shape (num_staves, num_stafflines, width)."""
    return self.output
def _squared_euclidean_distance_matrix(a, b):
  """Computes the pairwise squared Euclidean distances for binary rows.

  Expands sum_k (a[i, k] - b[j, k]) ** 2 as
  sum_k (a[i, k] + b[j, k] - 2 * a[i, k] * b[j, k]), which is exact when all
  entries are 0 or 1 (then x ** 2 == x). NOTE(review): for non-binary inputs
  this is not the true squared distance — confirm inputs are binarized.

  Args:
    a: Tensor of shape (num_a, k).
    b: Tensor of shape (num_b, k).

  Returns:
    Tensor of shape (num_a, num_b).
  """
  cross_terms = tf.matmul(a, b, transpose_b=True)
  a_totals = tf.reshape(tf.reduce_sum(a, axis=1), [-1, 1])  # column vector
  b_totals = tf.reshape(tf.reduce_sum(b, axis=1), [1, -1])  # row vector
  return a_totals + b_totals - 2 * cross_terms
================================================
FILE: moonlight/glyphs/knn_model.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""K-Nearest-Neighbors glyph classification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow_estimator.python.estimator.canned import prediction_keys
from moonlight.protobuf import musicscore_pb2
# k = 3 has the best performance for noteheads, clefs, and sharps. k = 5 seems
# to increase false negatives, so we probably don't want to increase k further
# with our current data.
K_NEAREST_VALUE = 3
NUM_GLYPHS = len(musicscore_pb2.Glyph.Type.keys())
def knn_kmeans_model(centroids, labels, patches=None):
  """The KNN k-means classifier model.

  Args:
    centroids: The k-means centroids NumPy array. Shape `(num_centroids,
      patch_height, patch_width)`.
    labels: The centroid labels NumPy array. Vector with length `num_centroids`.
    patches: Optional input tensor for the patches. If None, a placeholder will
      be used.

  Returns:
    The predictions (class ids) tensor determined from the input patches. Vector
    with the same length as `patches`.
  """
  with tf.name_scope('knn_model'):
    # Round-trip the centroids through uint8 so the embedded graph constant is
    # compact; values are restored to floats afterwards.
    centroids = tf.identity(
        _to_float(tf.constant(_to_uint8(centroids))), name='centroids')
    labels = tf.constant(labels, name='labels')
    centroids_shape = tf.shape(centroids)
    num_centroids = centroids_shape[0]
    patch_height = centroids_shape[1]
    patch_width = centroids_shape[2]
    # Flatten each centroid patch into a single feature vector.
    flattened_centroids = tf.reshape(
        centroids, [num_centroids, patch_height * patch_width],
        name='flattened_centroids')
    if patches is None:
      # Named placeholder so export_knn_model can look the input up by name.
      patches = tf.placeholder(
          tf.float32, (None, centroids.shape[1], centroids.shape[2]),
          name='patches')
    patches_shape = tf.shape(patches)
    flattened_patches = tf.reshape(
        patches, [patches_shape[0], patches_shape[1] * patches_shape[2]],
        name='flattened_patches')
    with tf.name_scope('distance_matrix'):
      distance_matrix = _squared_euclidean_distance_matrix(
          flattened_patches, flattened_centroids)
    # Take the k centroids with the lowest distance to each patch. Wrap the k
    # constant in a tf.identity, which tests can use to feed in another value.
    k_value = tf.identity(tf.constant(K_NEAREST_VALUE), name='k_nearest_value')
    nearest_centroid_inds = tf.nn.top_k(-distance_matrix, k=k_value)[1]
    # Get the label corresponding to each nearby centroids, and reshape the
    # labels back to the original shape.
    nearest_labels = tf.reshape(
        tf.gather(labels, tf.reshape(nearest_centroid_inds, [-1])),
        tf.shape(nearest_centroid_inds),
        name='nearest_labels')
    # Make a histogram of counts for each glyph type in the nearest centroids,
    # for each row (patch).
    length = NUM_GLYPHS
    bins = tf.map_fn(
        lambda row: tf.bincount(row, minlength=length, maxlength=length),
        tf.to_int32(nearest_labels),
        name='bins')
    with tf.name_scope('mode_out_of_k'):
      # Take the argmax of the histogram to get the top prediction. Discard
      # glyph type 1 (NONE) for now. The slice drops bins [0, NONE], so the
      # argmax index is offset by NONE + 1 = 2 to recover the glyph type.
      mode_out_of_k = tf.argmax(
          bins[:, musicscore_pb2.Glyph.NONE + 1:], axis=1) + 2
      # Force predictions to NONE only if all k nearby centroids were NONE.
      # Otherwise, the non-NONE nearby centroids will contribute to the
      # prediction.
      mode_out_of_k = tf.where(
          tf.equal(bins[:, musicscore_pb2.Glyph.NONE], k_value),
          tf.fill(
              tf.shape(mode_out_of_k), tf.to_int64(musicscore_pb2.Glyph.NONE)),
          mode_out_of_k)
    return tf.identity(mode_out_of_k, name='predictions')
def _to_uint8(values):
return np.rint(values * 255).astype(np.uint8)
def _to_float(values_t):
  """Maps uint8 tensor values in [0, 255] back to floats in [0, 1]."""
  as_float = tf.to_float(values_t)
  return as_float / tf.constant(255.)
def export_knn_model(centroids, labels, export_path):
  """Writes the KNN saved model.

  Args:
    centroids: The k-means centroids NumPy array.
    labels: The labels of the k-means centroids.
    export_path: The output saved model directory.
  """
  g = tf.Graph()
  with g.as_default():
    predictions = knn_kmeans_model(centroids, labels)
    # knn_kmeans_model created a named placeholder input; look it up by name.
    patches = g.get_tensor_by_name('knn_model/patches:0')
    predictions_info = tf.saved_model.utils.build_tensor_info(predictions)
    patches_info = tf.saved_model.utils.build_tensor_info(patches)
  with tf.Session(graph=g) as sess:
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    # Export a single default serving signature mapping the 'input' patches to
    # predicted class ids.
    builder.add_meta_graph_and_variables(
        sess, ['serve'],
        signature_def_map={
            tf.saved_model.signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'input': patches_info},
                    outputs={
                        prediction_keys.PredictionKeys.CLASS_IDS:
                            predictions_info
                    }),
        })
    builder.save()
def _squared_euclidean_distance_matrix(a, b):
  """Computes the pairwise squared Euclidean distances for binary rows.

  Expands sum_k (a[i, k] - b[j, k]) ** 2 as
  sum_k (a[i, k] + b[j, k] - 2 * a[i, k] * b[j, k]), which is exact when all
  entries are 0 or 1 (then x ** 2 == x). NOTE(review): for non-binary inputs
  this is not the true squared distance — confirm inputs are binarized.

  Args:
    a: Tensor of shape (num_a, k).
    b: Tensor of shape (num_b, k).

  Returns:
    Tensor of shape (num_a, num_b).
  """
  cross_terms = tf.matmul(a, b, transpose_b=True)
  a_totals = tf.reshape(tf.reduce_sum(a, axis=1), [-1, 1])  # column vector
  b_totals = tf.reshape(tf.reduce_sum(b, axis=1), [1, -1])  # row vector
  return a_totals + b_totals - 2 * cross_terms
================================================
FILE: moonlight/glyphs/knn_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the KNN glyph classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
import pandas as pd
import tensorflow as tf
from moonlight.glyphs import base
from moonlight.glyphs import knn
from moonlight.protobuf import musicscore_pb2
STAFF_INDEX = base.GlyphsTensorColumns.STAFF_INDEX
Y_POSITION = base.GlyphsTensorColumns.Y_POSITION
X = base.GlyphsTensorColumns.X
TYPE = base.GlyphsTensorColumns.TYPE
Glyph = musicscore_pb2.Glyph # pylint: disable=invalid-name
class KnnTest(tf.test.TestCase):
  """End-to-end test of KNN glyph classification on a synthetic staffline."""

  def testFakeStaffline(self):
    """Writes labeled 3x3 patches to a TFRecord and classifies a staffline."""
    # Staffline containing fake glyphs.
    staffline = tf.constant(
        [[1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1],
         [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1],
         [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1]])
    staffline = tf.cast(staffline, tf.float32)
    # Mapping of glyphs (row-major order) to glyph types. Every possible
    # single-pixel patch maps to NONE so that unmatched columns are labeled
    # NONE rather than an arbitrary nearest neighbor.
    patterns = {
        (1, 0, 1, 0, 1, 0, 1, 0, 1): musicscore_pb2.Glyph.CLEF_TREBLE,
        (0, 0, 0, 0, 1, 0, 0, 0, 0): musicscore_pb2.Glyph.NOTEHEAD_FILLED,
        (1, 1, 1, 1, 0, 1, 1, 1, 1): musicscore_pb2.Glyph.NOTEHEAD_EMPTY,
        (0, 0, 0, 0, 0, 0, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (1, 0, 0, 0, 0, 0, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 1, 0, 0, 0, 0, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 1, 0, 0, 0, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 0, 1, 0, 0, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 0, 0, 0, 1, 0, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 0, 0, 0, 0, 1, 0, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 0, 0, 0, 0, 0, 1, 0): musicscore_pb2.Glyph.NONE,
        (0, 0, 0, 0, 0, 0, 0, 0, 1): musicscore_pb2.Glyph.NONE,
    }
    with tf.Session():
      with tempfile.NamedTemporaryFile(mode='r') as examples_file:
        with tf.python_io.TFRecordWriter(examples_file.name) as writer:
          # Sort the keys for determinism.
          for pattern in sorted(patterns):
            example = tf.train.Example()
            # Each example is a flattened 3x3 float patch with its label.
            example.features.feature['patch'].float_list.value.extend(pattern)
            example.features.feature['height'].int64_list.value.append(3)
            example.features.feature['width'].int64_list.value.append(3)
            example.features.feature['label'].int64_list.value.append(
                patterns[pattern])
            writer.write(example.SerializeToString())

        class FakeStafflineExtractor(object):
          """Stub extractor that returns the synthetic staffline."""

          def extract_staves(self):
            return staffline[None, None, :, :]

        # stafflines are 4D (num_staves, num_stafflines, height, width).
        classifier = knn.NearestNeighborGlyphClassifier(
            examples_file.name, FakeStafflineExtractor(), run_min_length=1)
        # Force k=1 so each patch takes the label of its single nearest
        # neighbor.
        k_nearest_value = tf.get_default_graph().get_tensor_by_name(
            'k_nearest_value:0')
        glyphs = classifier.get_detected_glyphs().eval(
            feed_dict={k_nearest_value: 1})
        # The patches of the staffline that match non-NONE patterns (in
        # row-major order) should appear here (x is their center coordinate).
        # pyformat: disable
        expected_glyphs = pd.DataFrame(
            [
                {STAFF_INDEX: 0, Y_POSITION: 0, X: 1, TYPE: Glyph.CLEF_TREBLE},
                {STAFF_INDEX: 0, Y_POSITION: 0, X: 9, TYPE: Glyph.NOTEHEAD_EMPTY},
                {STAFF_INDEX: 0, Y_POSITION: 0, X: 14, TYPE: Glyph.CLEF_TREBLE},
                {STAFF_INDEX: 0, Y_POSITION: 0, X: 19, TYPE: Glyph.NOTEHEAD_FILLED},
            ],
            columns=[STAFF_INDEX, Y_POSITION, X, TYPE])
        self.assertAllEqual(glyphs, expected_glyphs)
if __name__ == '__main__':
  # Run the tests when executed directly.
  tf.test.main()
================================================
FILE: moonlight/glyphs/neural.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""1-D convolutional neural network glyph classifier model.
Convolves a filter horizontally along a staffline, to classify glyphs at each
x position.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.protobuf import musicscore_pb2
# Count every glyph type except for UNKNOWN_TYPE.
NUM_GLYPHS = len(musicscore_pb2.Glyph.Type.values()) - 1
# TODO(ringw): Make this extend BaseGlyphClassifier.
class NeuralNetworkGlyphClassifier(object):
  """Holds a TensorFlow NN model used for classifying glyphs on staff lines."""

  def __init__(self,
               input_placeholder,
               hidden_layer,
               reconstruction_layer=None,
               autoencoder_vars=None,
               labels_placeholder=None,
               prediction_layer=None,
               prediction_vars=None):
    """Builds the NeuralNetworkGlyphClassifier that holds the TensorFlow model.

    Args:
      input_placeholder: A tf.placeholder representing the input staffline
        image. Dtype float32 and shape (batch_size, target_height, None).
      hidden_layer: An inner layer in the model. Should be the last layer in
        the autoencoder model before reconstructing the input, and/or an
        intermediate layer in the prediction network. hidden_layer is intended
        to be the last common ancestor of the reconstruction_layer output and
        the prediction_layer output, if both are present.
      reconstruction_layer: The reconstruction of the input, for an autoencoder
        model. If non-None, should have the same shape as input_placeholder.
      autoencoder_vars: The variables for the autoencoder model (parameters
        affecting hidden_layer and reconstruction_layer), or None. If non-None,
        a dict mapping variable name to tf.Variable object.
      labels_placeholder: The labels tensor. A placeholder will be created if
        None is given. Dtype int32 and shape (batch_size, width). Values are
        between 0 and NUM_GLYPHS - 1 (where each value is the Glyph.Type enum
        value minus one, to skip UNKNOWN_TYPE).
      prediction_layer: The logit probability of each glyph for each column.
        Must be able to be passed to tf.nn.softmax to produce the probability
        of each glyph. 2D (width, NUM_GLYPHS). May be None if the model is not
        being used for classification.
      prediction_vars: The variables for the classification model (parameters
        affecting hidden_layer and prediction_layer), or None. If non-None, a
        dict mapping variable name to tf.Variable object.
    """
    self.input_placeholder = input_placeholder
    self.hidden_layer = hidden_layer
    self.reconstruction_layer = reconstruction_layer
    self.autoencoder_vars = autoencoder_vars or {}
    # Calculate the loss that will be minimized for the autoencoder model.
    # Mean squared error between the input image and its reconstruction.
    self.autoencoder_loss = None
    if self.reconstruction_layer is not None:
      self.autoencoder_loss = (
          tf.reduce_mean(
              tf.squared_difference(self.input_placeholder,
                                    self.reconstruction_layer)))
    self.prediction_layer = prediction_layer
    self.prediction_vars = prediction_vars or {}
    self.labels_placeholder = (
        labels_placeholder if labels_placeholder is not None else
        tf.placeholder(tf.int32, (None, None)))
    # Calculate the loss that will be minimized for the prediction model.
    # Per-column softmax cross-entropy against one-hot labels.
    self.prediction_loss = None
    if self.prediction_layer is not None:
      self.prediction_loss = (
          tf.reduce_mean(
              tf.nn.softmax_cross_entropy_with_logits(
                  logits=self.prediction_layer,
                  labels=tf.one_hot(self.labels_placeholder, NUM_GLYPHS))))
      # The probabilities of each glyph for each column.
      self.prediction = tf.nn.softmax(self.prediction_layer)

  def get_autoencoder_initializers(self):
    """Gets the autoencoder initializer ops.

    Returns:
      The list of TensorFlow ops which initialize the autoencoder model.
    """
    return [var.initializer for var in self.autoencoder_vars.values()]

  def get_classifier_initializers(self):
    """Gets the classifier initializer ops.

    Returns:
      The list of TensorFlow ops which initialize the classifier model.
    """
    return [var.initializer for var in self.prediction_vars.values()]

  @staticmethod
  def semi_supervised_model(batch_size,
                            target_height,
                            input_placeholder=None,
                            labels_placeholder=None):
    """Constructs the semi-supervised model.

    Consists of an autoencoder and classifier, sharing a hidden layer.

    Args:
      batch_size: The number of staffline images in a batch, which must be
        known at model definition time. int.
      target_height: The height of each scaled staffline image. int.
      input_placeholder: The input layer. A placeholder will be created if None
        is given. Dtype float32 and shape (batch_size, target_height,
        any_width).
      labels_placeholder: The labels tensor. A placeholder will be created if
        None is given. Dtype int32 and shape (batch_size, width).

    Returns:
      A NeuralNetworkGlyphClassifier instance holding the model.
    """
    if input_placeholder is None:
      input_placeholder = tf.placeholder(tf.float32,
                                         (batch_size, target_height, None))
    autoencoder_vars = {}
    prediction_vars = {}
    # The first two layers are shared by the autoencoder and the classifier.
    hidden, layer_vars = InputConvLayer(input_placeholder, 10).get()
    autoencoder_vars.update(layer_vars)
    prediction_vars.update(layer_vars)
    hidden, layer_vars = HiddenLayer(hidden, 10, 10).get()
    autoencoder_vars.update(layer_vars)
    prediction_vars.update(layer_vars)
    # The reconstruction branch belongs only to the autoencoder.
    reconstruction, layer_vars = ReconstructionLayer(hidden, target_height,
                                                     target_height).get()
    autoencoder_vars.update(layer_vars)
    # The remaining layers belong only to the classifier.
    hidden, layer_vars = HiddenLayer(hidden, 10, 10, name="hidden_2").get()
    prediction_vars.update(layer_vars)
    prediction, layer_vars = PredictionLayer(hidden).get()
    prediction_vars.update(layer_vars)
    return NeuralNetworkGlyphClassifier(
        input_placeholder,
        hidden,
        reconstruction_layer=reconstruction,
        autoencoder_vars=autoencoder_vars,
        labels_placeholder=labels_placeholder,
        prediction_layer=prediction,
        prediction_vars=prediction_vars)
class BaseLayer(object):
  """Base class for model layers; owns the filter weights and bias variables.

  Subclasses are expected to set self.output to the layer's output tensor.
  """

  def __init__(self, filter_size, n_in, n_out, name):
    # 1D convolution filter weights of shape (filter_size, n_in, n_out).
    self.weights = tf.Variable(
        tf.truncated_normal((filter_size, n_in, n_out)), name=name + "_W")
    self.bias = tf.Variable(tf.zeros(n_out), name=name + "_bias")
    # Map from TF variable name to the variable, for collecting the trainable
    # parameters of each sub-model.
    self.vars = {self.weights.name: self.weights, self.bias.name: self.bias}

  def get(self):
    """Gets the layer output and variables.

    Returns:
      The output tensor of the layer.
      The dict of variables (parameters) for the layer.
    """
    return self.output, self.vars
class InputConvLayer(BaseLayer):
  """Convolves the input image strip, producing multiple outputs per column."""

  def __init__(self, image, n_hidden, activation=tf.nn.sigmoid, name="input"):
    """Creates the InputConvLayer.

    Args:
      image: The input image (height, width). Should be wider than it is tall.
      n_hidden: The number of output nodes of the layer.
      activation: Callable applied to the convolved image. Applied to the 1D
        convolution result to produce the activation of the layer.
      name: The prefix for variable names for the layer. Produces self.output
        with shape (width, n_hidden).
    """
    # The filter is as wide as the image is tall, and each image row acts as
    # an input channel (see the transpose below).
    height = int(image.get_shape()[1])
    super(InputConvLayer, self).__init__(
        filter_size=height, n_in=height, n_out=n_hidden, name=name)
    self.input = image
    # Transpose the image, so that the rows are "channels" in a 1D input.
    self.output = activation(
        tf.nn.conv1d(
            tf.transpose(image, [0, 2, 1]),
            self.weights,
            stride=1,
            padding="SAME") + self.bias[None, None, :])
class HiddenLayer(BaseLayer):
  """Performs a 1D convolution between hidden layers in the model."""

  def __init__(self,
               layer_in,
               filter_size,
               n_out,
               activation=tf.nn.sigmoid,
               name="hidden"):
    """Performs a 1D convolution between hidden layers in the model.

    Args:
      layer_in: The input layer (width, num_channels).
      filter_size: The width of the convolution filter.
      n_out: The number of output channels.
      activation: Callable applied to the convolved image. Applied to the 1D
        convolution result to produce the activation of the layer.
      name: The prefix for variable names for the layer. Produces self.output
        with shape (width, n_out).
    """
    # The number of input channels is inferred from the input layer's shape.
    n_in = int(layer_in.get_shape()[2])
    super(HiddenLayer, self).__init__(filter_size, n_in, n_out, name)
    self.output = activation(
        tf.nn.conv1d(layer_in, self.weights, stride=1, padding="SAME") +
        self.bias[None, None, :])
class ReconstructionLayer(BaseLayer):
  """Outputs a reconstructed layer."""

  def __init__(self,
               layer_in,
               filter_size,
               out_height,
               activation=tf.nn.sigmoid,
               name="reconstruction"):
    """Outputs a reconstructed image of shape (out_height, width).

    Args:
      layer_in: The input layer (width, num_channels).
      filter_size: The width of the convolution filter.
      out_height: The height of the output image.
      activation: Callable applied to the convolved image. Applied to the 1D
        convolution result to produce the activation of the output.
      name: The prefix for variable names for the layer. Produces self.output
        with shape (batch, out_height, width), matching the input image
        layout.
    """
    n_in = int(layer_in.get_shape()[2])
    super(ReconstructionLayer, self).__init__(filter_size, n_in, out_height,
                                              name)
    output = activation(
        tf.nn.conv1d(layer_in, self.weights, stride=1, padding="SAME") +
        self.bias[None, None, :])
    # The conv1d output is (batch, width, out_height); transpose the channels
    # back into image rows so the result is comparable to the input image.
    self.output = tf.transpose(output, [0, 2, 1])
class PredictionLayer(BaseLayer):
  """Classifies each column from a hidden layer."""

  def __init__(self, layer_in, name="prediction"):
    """Outputs logit predictions for each column from a hidden layer.

    Args:
      layer_in: The input layer (width, num_channels).
      name: The prefix for variable names for the layer. Produces the logits
        for each class in self.output. Shape (width, NUM_GLYPHS)
    """
    n_in = int(layer_in.get_shape()[2])
    n_out = NUM_GLYPHS
    super(PredictionLayer, self).__init__(1, n_in, n_out, name)
    # Flatten (batch, width) into one axis so every column is classified
    # independently with a single matmul.
    input_shape = tf.shape(layer_in)
    input_columns = tf.reshape(
        layer_in, [input_shape[0] * input_shape[1], input_shape[2]])
    # Ignore the 0th axis of the weights (convolutional filter, which is 1 here)
    weights = self.weights[0, :, :]
    output = tf.matmul(input_columns, weights) + self.bias
    self.output = tf.reshape(output,
                             [input_shape[0], input_shape[1], NUM_GLYPHS])
================================================
FILE: moonlight/glyphs/neural_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the neural network glyph classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.glyphs import neural
NUM_STAFFLINES = 5
TARGET_HEIGHT = 15
IMAGE_WIDTH = 100
class NeuralNetworkGlyphClassifierTest(tf.test.TestCase):
  """Smoke tests for the semi-supervised glyph classifier model."""

  def testSemiSupervisedClassifier(self):
    # Ensure that the losses can be evaluated without error.
    stafflines = tf.random_uniform((NUM_STAFFLINES, TARGET_HEIGHT, IMAGE_WIDTH))
    # Use every single glyph once except for NONE (0), padding the rest of
    # each row with zeros.
    one_of_each_glyph = np.arange(neural.NUM_GLYPHS)
    padding = np.zeros(IMAGE_WIDTH - neural.NUM_GLYPHS)
    single_row = np.concatenate([one_of_each_glyph, padding]).astype(np.int32)
    labels = np.repeat(single_row[None, :], NUM_STAFFLINES, axis=0)
    model = neural.NeuralNetworkGlyphClassifier.semi_supervised_model(
        batch_size=NUM_STAFFLINES,
        target_height=TARGET_HEIGHT,
        input_placeholder=stafflines,
        labels_placeholder=tf.constant(labels))
    with self.test_session() as sess:
      # The autoencoder must run successfully with only its vars initialized.
      # The loss must always be positive.
      sess.run(model.get_autoencoder_initializers())
      self.assertGreater(sess.run(model.autoencoder_loss), 0)
      # The classifier must run successfully with its vars initialized too.
      sess.run(model.get_classifier_initializers())
      self.assertGreater(sess.run(model.prediction_loss), 0)
if __name__ == '__main__':
  # Run the tests when executed directly.
  tf.test.main()
================================================
FILE: moonlight/glyphs/note_dots.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detects dots which are attached to noteheads.
Dots are round, solid, smaller than other glyphs, and are typically spaced so
that they don't intersect with staff lines. Therefore, we detect them from the
connected components. We determine that the component is round-ish and solid if
the area (black pixel count) is at least half the area of the bounds of the
component.
Candidate note dots are components that are round-ish and follow the expected
size. For each notehead, we look for candidate dots slightly to the right of the
note to assign to it.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from moonlight.glyphs import geometry
from moonlight.glyphs import glyph_types
from moonlight.protobuf import musicscore_pb2
from moonlight.structure import components
COMPONENTS = components.ConnectedComponentsColumns
class NoteDots(object):
  """Finds candidate note dots and attaches them to dotted noteheads."""

  def __init__(self, structure):
    # (N, 2) array of candidate dot centers, one (x, y) row per candidate.
    self.dots = _extract_dots(structure)

  def apply(self, page):
    """Detects note dots in the page.

    Dots must be to the right of a notehead.

    Args:
      page: A `Page` message.

    Returns:
      The same `Page`, with note dots added in place.
    """
    for system in page.system:
      for staff in system.staff:
        for glyph in staff.glyph:
          if not glyph_types.is_dotted_notehead(glyph):
            continue
          # Search slightly to the right of the notehead, within half a
          # staffline distance vertically of the notehead's center.
          center_y = geometry.glyph_y(staff, glyph)
          half_distance = staff.staffline_distance / 2.
          in_x = _is_in_range(glyph.x, self.dots[:, 0],
                              glyph.x + staff.staffline_distance * 3.)
          in_y = _is_in_range(center_y - half_distance, self.dots[:, 1],
                              center_y + half_distance)
          matched_dots = self.dots[in_x & in_y]
          glyph.dot.extend(
              musicscore_pb2.Point(x=dot[0], y=dot[1])
              for dot in matched_dots)
    return page
def _is_in_range(min_value, values, max_value):
return np.logical_and(min_value <= values, values <= max_value)
def _extract_dots(structure):
  """Returns candidate note dots.

  Note dots must be connected components which are roundish (the area of the
  component's bounds is at least half full), and are the expected size.

  Args:
    structure: A computed `Structure`.

  Returns:
    A numpy array of shape `(N, 2)`. Each entry holds the center `(x, y)` of a
    candidate note dot.
  """
  staff_detector = structure.staff_detector
  min_size = staff_detector.staffline_thickness + 1
  # TODO(ringw): Are note dots typically smaller in ossia parts?
  max_size = np.median(staff_detector.staffline_distance) * 2 / 3
  components_array = structure.connected_components.components
  width = (components_array[:, COMPONENTS.X1] -
           components_array[:, COMPONENTS.X0])
  height = (components_array[:, COMPONENTS.Y1] -
            components_array[:, COMPONENTS.Y0])
  # Round-ish: black pixels fill at least half of the bounding box.
  is_full = np.greater_equal(components_array[:, COMPONENTS.SIZE] * 2,
                             width * height)
  keep = (is_full
          & _is_in_range(min_size, width, max_size)
          & _is_in_range(min_size, height, max_size))
  candidates = components_array[keep]
  center_x = (candidates[:, COMPONENTS.X0] + candidates[:, COMPONENTS.X1]) / 2
  center_y = (candidates[:, COMPONENTS.Y0] + candidates[:, COMPONENTS.Y1]) / 2
  return np.c_[center_x, center_y].astype(int)
================================================
FILE: moonlight/glyphs/repeated.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fixes duplicate rests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.glyphs import glyph_types
class FixRepeatedRests(object):
  """Removes rests that duplicate the glyph immediately before them."""

  def apply(self, page):
    """Remove duplicate rests of the same type.

    A rest is dropped when it has the same type as the previous glyph on the
    staff and lies within one staffline distance of it.

    Args:
      page: A `Page` message.

    Returns:
      The same `Page`, with duplicate rests removed in place.
    """
    for system in page.system:
      for staff in system.staff:
        duplicates = []
        previous = None
        for glyph in staff.glyph:
          is_duplicate = (
              previous is not None and glyph_types.is_rest(glyph) and
              previous.type == glyph.type and
              glyph.x - previous.x < staff.staffline_distance)
          if is_duplicate:
            duplicates.append(glyph)
          # NOTE(review): every glyph (not only rests) becomes the next
          # comparison baseline; this matches the original behavior.
          previous = glyph
        for duplicate in duplicates:
          staff.glyph.remove(duplicate)
    return page
================================================
FILE: moonlight/glyphs/saved_classifier.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Saved patch classifier models for OMR.
The saved model should accept a 3D tensor of patches
`(num_patches, patch_height, patch_width)`, and return a vector of class ids of
length `num_patches`. Horizontal patches are extracted from each vertical
position on each staff where we expect to find glyphs, and any arbitrary model
can be loaded here.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import graph_editor as contrib_graph_editor
from tensorflow.contrib import util as contrib_util
from tensorflow_estimator.python.estimator.canned import prediction_keys
from moonlight.glyphs import convolutional
from moonlight.staves import staffline_extractor
from moonlight.util import patches
_SIGNATURE_KEYS = [
# Created if the ServingInputReceiver has 'patch' in
# receiver_tensors_alternatives.
'patch:predict',
# Seems to only be created if receiver_tensors_alternatives is not set.
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
]
_RUN_MIN_LENGTH_CONSTANT_NAME = 'runtime_params/run_min_length:0'
class SavedConvolutional1DClassifier(
    convolutional.Convolutional1DGlyphClassifier):
  """Holds a saved glyph classifier model.

  To use a saved glyph classifier with `OMREngine`, see the
  `saved_classifier_fn` wrapper.
  """

  def __init__(self,
               structure,
               saved_model_dir,
               num_sections=19,
               *args,
               **kwargs):
    """Loads a saved classifier model for the OMR engine.

    Args:
      structure: A `structure.Structure`.
      saved_model_dir: Path to the TF saved_model directory to load.
      num_sections: Number of vertical positions of patches to extract,
        centered on the middle staff line.
      *args: Passed through to the `Convolutional1DGlyphClassifier` ctor.
      **kwargs: Passed through to the `Convolutional1DGlyphClassifier` ctor.

    Raises:
      ValueError: If the saved model input could not be interpreted as a 3D
        array with the patch size, or no known signature is present.
    """
    super(SavedConvolutional1DClassifier, self).__init__(*args, **kwargs)
    # The saved model is loaded into the current default session's graph.
    sess = tf.get_default_session()
    graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
    signature = None
    for key in _SIGNATURE_KEYS:
      if key in graph_def.signature_def:
        signature = graph_def.signature_def[key]
        break
    else:
      # for/else is only executed if the loop completes without breaking.
      raise ValueError('One of the following signatures must be present: %s' %
                       _SIGNATURE_KEYS)
    input_info = signature.inputs['input']
    # The input must be 3D (num_patches, patch_height, patch_width) with a
    # statically-known patch size.
    if not (len(input_info.tensor_shape.dim) == 3 and
            input_info.tensor_shape.dim[1].size > 0 and
            input_info.tensor_shape.dim[2].size > 0):
      raise ValueError('Invalid patches input: ' + str(input_info))
    patch_height = input_info.tensor_shape.dim[1].size
    patch_width = input_info.tensor_shape.dim[2].size

    with tf.name_scope('saved_classifier'):
      self.staffline_extractor = staffline_extractor.StafflineExtractor(
          structure.staff_remover.remove_staves,
          structure.staff_detector,
          num_sections=num_sections,
          target_height=patch_height)
      stafflines = self.staffline_extractor.extract_staves()
      num_staves = tf.shape(stafflines)[0]
      num_sections = tf.shape(stafflines)[1]
      staffline_patches = patches.patches_1d(stafflines, patch_width)
      staffline_patches_shape = tf.shape(staffline_patches)
      patches_per_position = staffline_patches_shape[2]
      flat_patches = tf.reshape(staffline_patches, [
          num_staves * num_sections * patches_per_position, patch_height,
          patch_width
      ])
      # Feed in the flat extracted patches as the classifier input.
      # graph_replace rewires the loaded graph so that its input placeholder
      # is replaced by the extracted-patches tensor.
      predictions_name = signature.outputs[
          prediction_keys.PredictionKeys.CLASS_IDS].name
      predictions = contrib_graph_editor.graph_replace(
          sess.graph.get_tensor_by_name(predictions_name), {
              sess.graph.get_tensor_by_name(signature.inputs['input'].name):
                  flat_patches
          })
      # Reshape to the original patches shape.
      predictions = tf.reshape(predictions, staffline_patches_shape[:3])
      # Pad the output. We take only the valid patches, but we want to shift
      # all of the predictions so that a patch at index i on the x-axis is
      # centered on column i. This determines the x coordinates of the glyphs.
      width = tf.shape(stafflines)[-1]
      predictions_width = tf.shape(predictions)[-1]
      pad_before = (width - predictions_width) // 2
      pad_shape_before = tf.concat([staffline_patches_shape[:2], [pad_before]],
                                   axis=0)
      pad_shape_after = tf.concat([
          staffline_patches_shape[:2], [width - predictions_width - pad_before]
      ],
                                  axis=0)
      self.output = tf.concat(
          [
              # NONE has value 1.
              tf.ones(pad_shape_before, tf.int64),
              tf.to_int64(predictions),
              tf.ones(pad_shape_after, tf.int64),
          ],
          axis=-1)

    # run_min_length can be set on the saved model to tweak its behavior, but
    # should be overridden by the keyword argument.
    if 'run_min_length' not in kwargs:
      try:
        # Try to read the run min length from the saved model. This is tweaked
        # on a per-model basis.
        run_min_length_t = sess.graph.get_tensor_by_name(
            _RUN_MIN_LENGTH_CONSTANT_NAME)
        run_min_length = contrib_util.constant_value(run_min_length_t)
        # Implicit comparison is invalid on a NumPy array.
        # pylint: disable=g-explicit-bool-comparison
        if run_min_length is None or run_min_length.shape != ():
          raise ValueError('Bad run_min_length: {}'.format(run_min_length))
        # Overwrite the property after the Convolutional1DGlyphClassifier
        # constructor completes.
        self.run_min_length = int(run_min_length)
      except KeyError:
        pass  # No run_min_length tensor in the saved model.

  @property
  def staffline_predictions(self):
    # The padded per-column predictions computed in __init__.
    return self.output
================================================
FILE: moonlight/glyphs/saved_classifier_fn.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides the default included OMR classifier.
This assumes a `saved_model.pb` file with no variables in the model (which would
be separate files).
This is an internal version that can be run in a .PAR file. This module will not
be shared between Piper and Git, which will have a simpler open source version.
The open source version is just a wrapper around the
`SavedConvolutional1DClassifier` ctor, which assumes that the saved model is a
real directory.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
from moonlight.glyphs import saved_classifier
_SAVED_MODEL_PATH = '../data/glyphs_nn_model_20180808'
def build_classifier_fn(saved_model=None):
  """Returns a glyph classifier fn for a saved model.

  The result can be given to `OMREngine` to configure the saved model to use.

  Args:
    saved_model: Saved model directory. If None, uses the default KNN saved
      model included with Magenta.

  Returns:
    A callable that accepts a `Structure` and returns a `BaseGlyphClassifier`.
  """
  if not saved_model:
    data_dir = tf.resource_loader.get_data_files_path()
    saved_model = os.path.join(data_dir, _SAVED_MODEL_PATH)

  def classifier_fn(structure):
    return saved_classifier.SavedConvolutional1DClassifier(
        structure, saved_model)

  return classifier_fn
================================================
FILE: moonlight/glyphs/saved_classifier_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests running OMR with a dummy saved model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tempfile
import numpy as np
import tensorflow as tf
from moonlight import image
from moonlight import structure
from moonlight.glyphs import convolutional
from moonlight.glyphs import saved_classifier
from moonlight.protobuf import musicscore_pb2
class SavedClassifierTest(tf.test.TestCase):
  """Round-trips a dummy saved model through the classifier loader."""

  def testSaveAndLoadDummyClassifier(self):
    """Saves a constant-output model, then loads it and runs predictions."""
    with tempfile.TemporaryDirectory() as base_dir:
      export_dir = os.path.join(base_dir, 'export')
      with self.test_session() as sess:
        # The dummy model accepts 18x15 patches and always predicts class 1.
        patches = tf.placeholder(tf.float32, shape=(None, 18, 15))
        num_patches = tf.shape(patches)[0]
        # Glyph.NONE is number 1.
        class_ids = tf.ones([num_patches], tf.int32)
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            # pyformat: disable
            {'input': tf.saved_model.utils.build_tensor_info(patches)},
            {'class_ids': tf.saved_model.utils.build_tensor_info(class_ids)},
            'serve')
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess, ['serve'],
            signature_def_map={
                tf.saved_model.signature_constants
                .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    signature
            })
        builder.save()
      tf.reset_default_graph()
      # Load the saved model.
      with self.test_session() as sess:
        filename = os.path.join(tf.resource_loader.get_data_files_path(),
                                '../testdata/IMSLP00747-000.png')
        page = image.decode_music_score_png(tf.read_file(filename))
        clazz = saved_classifier.SavedConvolutional1DClassifier(
            structure.create_structure(page), export_dir)
        # Run min length should be the default.
        self.assertEqual(clazz.run_min_length,
                         convolutional.DEFAULT_RUN_MIN_LENGTH)
        predictions = clazz.staffline_predictions.eval()
        self.assertEqual(predictions.ndim, 3)  # Staff, staff position, x
        self.assertGreater(predictions.size, 0)
        # Predictions are all musicscore_pb2.Glyph.NONE.
        self.assertAllEqual(
            predictions,
            np.full(predictions.shape, musicscore_pb2.Glyph.NONE, np.int32))
if __name__ == '__main__':
  # Run the tests when executed directly.
  tf.test.main()
================================================
FILE: moonlight/glyphs/testing.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utilities for glyph classification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.glyphs import convolutional
from moonlight.protobuf import musicscore_pb2
# Sample glyph predictions for two staves of three staff lines each, six
# pixels wide. Shape (num_staves, num_stafflines, width).
_STAFF_1 = [[1, 1, 1, 1, 2, 1],
            [3, 1, 5, 1, 1, 1],
            [1, 4, 1, 1, 1, 1]]
_STAFF_2 = [[1, 1, 3, 1, 1, 1],
            [1, 1, 5, 1, 1, 1],
            [1, 1, 1, 1, 3, 5]]
PREDICTIONS = np.asarray([_STAFF_1, _STAFF_2])
# Page corresponding to the glyphs in PREDICTIONS. One StaffSystem with two
# Staff messages, matching the two staves in PREDICTIONS. Each Glyph has an
# x (column index), a y_position (presumably the vertical staff position
# derived from the staffline row — TODO confirm against the decoder), and a
# type (the prediction value at that column).
GLYPHS_PAGE = musicscore_pb2.Page(system=[
    musicscore_pb2.StaffSystem(staff=[
        musicscore_pb2.Staff(glyph=[
            musicscore_pb2.Glyph(x=0, y_position=0, type=3),
            musicscore_pb2.Glyph(x=1, y_position=-1, type=4),
            musicscore_pb2.Glyph(x=2, y_position=0, type=5),
            musicscore_pb2.Glyph(x=4, y_position=1, type=2),
        ]),
        musicscore_pb2.Staff(glyph=[
            musicscore_pb2.Glyph(x=2, y_position=1, type=3),
            musicscore_pb2.Glyph(x=2, y_position=0, type=5),
            musicscore_pb2.Glyph(x=4, y_position=-1, type=3),
            musicscore_pb2.Glyph(x=5, y_position=-1, type=5),
        ]),
    ]),
])
class DummyGlyphClassifier(convolutional.Convolutional1DGlyphClassifier):
    """Glyph classifier stub that always emits a fixed prediction array.

    The constant predictions have shape (num_staves, num_stafflines, width).
    """

    def __init__(self, predictions):
        """Creates the classifier.

        Args:
            predictions: Array-like of glyph type predictions with shape
                (num_staves, num_stafflines, width).
        """
        super(DummyGlyphClassifier, self).__init__(run_min_length=1)
        self.predictions = predictions

    @property
    def staffline_predictions(self):
        """Returns the fixed predictions as a constant tensor."""
        return tf.constant(self.predictions)
================================================
FILE: moonlight/image.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility for reading music score images.
Reads grayscale images, and reverses the values if the image is detected to be
inverted.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def decode_music_score_png(contents):
    """Reads a music score image.

    Decodes a binary or grayscale PNG and takes its only channel. If the image
    appears to be inverted (a majority of dark pixels), the values are flipped
    so that the white background has value 255 and the black content 0.

    Args:
        contents: PNG data in a scalar string tensor.

    Returns:
        The music score image. A two-dimensional tensor (HW) of type uint8.
    """
    with tf.name_scope("decode_music_score_png"):
        contents = tf.convert_to_tensor(contents, name="contents")
        decoded = tf.image.decode_png(contents, channels=1, dtype=tf.uint8)
        image_t = decoded[:, :, 0]

        def invert():
            # Sub op is not defined for uint8; compute 255 - x in int32.
            return tf.cast(255 - tf.cast(image_t, tf.int32), tf.uint8)

        threshold = 127
        dims = tf.shape(image_t)
        total_pixels = dims[0] * dims[1]
        dark_pixels = tf.reduce_sum(tf.cast(image_t < threshold, tf.int32))
        majority_dark = tf.greater(dark_pixels, total_pixels // 2)
        return tf.cond(majority_dark, invert, lambda: image_t)
================================================
FILE: moonlight/models/base/BUILD
================================================
# Description:
#   Common logic for all model training.

package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Dataset batching/limiting helpers driven by command-line flags.
py_library(
    name = "batches",
    srcs = ["batches.py"],
    srcs_version = "PY2AND3",
    deps = [
        # absl dep
    ],
)

py_test(
    name = "batches_test",
    srcs = ["batches_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":batches",
        # disable_tf2
        # absl dep
        # tensorflow dep
    ],
)

# Patch-based glyph model: input pipeline, serving signature, and metrics.
py_library(
    name = "glyph_patches",
    srcs = ["glyph_patches.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":batches",
        ":label_weights",
        # absl dep
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/util:memoize",
        # tensorflow dep
        # tensorflow.contrib.estimator py dep
        # tensorflow.contrib.image py dep
    ],
)

py_test(
    name = "glyph_patches_test",
    srcs = ["glyph_patches_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":glyph_patches",
        # disable_tf2
        # numpy dep
        # tensorflow dep
    ],
)

# Stores hyperparameters in the exported saved model for reproducibility.
py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
    srcs_version = "PY2AND3",
    deps = [
        # six dep
        # tensorflow dep
    ],
)

py_test(
    name = "hyperparameters_test",
    srcs = ["hyperparameters_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":hyperparameters",
        # disable_tf2
        # numpy dep
    ],
)

# Per-label example weight configuration for training.
py_library(
    name = "label_weights",
    srcs = ["label_weights.py"],
    srcs_version = "PY2AND3",
    deps = [
        # absl dep
        "//moonlight/protobuf:protobuf_py_pb2",
        # six dep
        # tensorflow dep
    ],
)

py_test(
    name = "label_weights_test",
    srcs = ["label_weights_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":label_weights",
        # disable_tf2
        # tensorflow dep
    ],
)
================================================
FILE: moonlight/models/base/batches.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility for batching and limiting the dataset size according to flags."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
FLAGS = flags.FLAGS

# Command-line knobs for the dataset pipeline built by `get_batched_tensor`.
flags.DEFINE_integer(
    'dataset_shuffle_buffer_size', 10000,
    'Shuffles this many entries in the dataset. 0 indicates no'
    ' shuffling.')
flags.DEFINE_integer(
    'dataset_limit_size', None,
    'Only take this many entries in the dataset (which may be repeated for some'
    ' number of training steps).')
flags.DEFINE_integer('dataset_batch_size', 32, 'Resulting batch size.')
def get_batched_tensor(dataset):
    """Gets the tensor representing a single batch from a `tf.data.Dataset`.

    Batch and epoch options are passed on the command line.

    Args:
        dataset: A `tf.data.Dataset` containing single examples.

    Returns:
        A dict of tensors, which contains the concatenated features from each
        example in a single batch. Each time the tensor is evaluated, it will
        produce the next batch.
    """
    shuffle_size = FLAGS.dataset_shuffle_buffer_size
    if shuffle_size:
        dataset = dataset.shuffle(buffer_size=shuffle_size)
    limit = FLAGS.dataset_limit_size
    if limit:
        dataset = dataset.take(limit)
    # Cycle through the batches indefinitely; training stops when the max
    # number of train steps is exhausted.
    batched = dataset.batch(FLAGS.dataset_batch_size).repeat()
    return batched.make_one_shot_iterator().get_next()
================================================
FILE: moonlight/models/base/batches_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for batches."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from moonlight.models.base import batches
import numpy as np
import tensorflow as tf
class BatchesTest(tf.test.TestCase):
    """Tests for batches.get_batched_tensor."""

    def testBatching(self):
        """Consecutive batches cover the (unshuffled) dataset in order."""
        all_as = np.random.rand(1000, 2, 3)
        all_bs = np.random.randint(0, 100, [1000], np.int32)
        all_labels = np.random.randint(0, 5, [1000], np.int32)
        random_dataset = tf.data.Dataset.from_tensor_slices(({
            'a': tf.constant(all_as),
            'b': tf.constant(all_bs)
        }, tf.constant(all_labels)))
        # Disable shuffling so batch contents are deterministic.
        flags.FLAGS.dataset_shuffle_buffer_size = 0
        batch_tensors = batches.get_batched_tensor(random_dataset)
        with self.test_session() as sess:
            batch = sess.run(batch_tensors)
            # First batch.
            self.assertEqual(len(batch), 2)
            self.assertEqual(sorted(batch[0].keys()), ['a', 'b'])
            batch_size = flags.FLAGS.dataset_batch_size
            self.assertAllEqual(batch[0]['a'], all_as[:batch_size])
            self.assertAllEqual(batch[0]['b'], all_bs[:batch_size])
            self.assertAllEqual(batch[1], all_labels[:batch_size])
            batch = sess.run(batch_tensors)
            # Second batch.
            self.assertEqual(len(batch), 2)
            self.assertEqual(sorted(batch[0].keys()), ['a', 'b'])
            batch_size = flags.FLAGS.dataset_batch_size
            self.assertAllEqual(batch[0]['a'], all_as[batch_size:batch_size * 2])
            self.assertAllEqual(batch[0]['b'], all_bs[batch_size:batch_size * 2])
            self.assertAllEqual(batch[1], all_labels[batch_size:batch_size * 2])
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
    tf.test.main()
================================================
FILE: moonlight/models/base/glyph_patches.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base patch-based glyph model.
For example, this accepts the staff patch k-means centroids emitted by
staffline_patches_kmeans_pipeline and labeled by kmeans_labeler.
This defines the input and signature of the model, and allows any type of
multi-class classifier using the normalized patches as input.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from absl import flags
from moonlight.models.base import batches
from moonlight.models.base import label_weights
from moonlight.protobuf import musicscore_pb2
from moonlight.util import memoize
import tensorflow as tf
from tensorflow.contrib import estimator as contrib_estimator
from tensorflow.contrib import image as contrib_image
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record
# Feature-dict key under which the per-example weight is passed to the
# estimator.
WEIGHT_COLUMN_NAME = 'weight'

FLAGS = flags.FLAGS

flags.DEFINE_string('train_input_patches', None,
                    'Glob of labeled patch TFRecords for training')
flags.DEFINE_string('eval_input_patches', None,
                    'Glob of labeled patch TFRecords for eval')
flags.DEFINE_string('model_dir', None, 'Output trained model directory')
flags.DEFINE_boolean(
    'use_included_label_weight', False,
    'Whether to multiply a "label_weight" feature included in the example by'
    ' the weight determined by the "label" value.')
# Data-augmentation knobs (see _augment_shift and _augment_rotation).
flags.DEFINE_float(
    'augmentation_x_shift_probability', 0.5,
    'Probability of shifting the patch left or right by one pixel. The edge is'
    ' filled using the adjacent column. It is equally likely that the patch is'
    ' shifted left or right.')
flags.DEFINE_float(
    'augmentation_max_rotation_degrees', 2.,
    'Max rotation of the patch, in degrees. The rotation is selected uniformly'
    ' randomly from the range +- this value. A value of 0 implies no rotation.')
# Training/eval loop configuration (see train_and_evaluate).
flags.DEFINE_integer('eval_throttle_secs', 60,
                     'Evaluate at at most this interval, in seconds.')
flags.DEFINE_integer(
    'train_max_steps', 100000,
    'Max steps for training. If 0, will train until the process is'
    ' interrupted.')
flags.DEFINE_integer('eval_steps', 500, 'Num steps to evaluate the model.')
flags.DEFINE_integer(
    'exports_to_keep', 10,
    'Keep the last N saved models (exported on each eval) before deleting'
    ' previous exports.')
flags.DEFINE_multi_string('classes_for_metrics', [
    'NONE',
    'CLEF_TREBLE',
    'CLEF_BASS',
    'NOTEHEAD_FILLED',
    'NOTEHEAD_EMPTY',
    'NOTEHEAD_WHOLE',
    'SHARP',
    'FLAT',
    'NATURAL',
], 'Generate accuracy metrics for these class names.')
@memoize.MemoizedFunction
def read_patch_dimensions():
  """Reads the dimensions of the input patches from disk.

  Parses the first example in the training set, which must have "height" and
  "width" features.

  Returns:
    Tuple of (height, width) read from disk, using the glob passed to
    --train_input_patches.

  Raises:
    ValueError: If no records are found under --train_input_patches.
  """
  for filename in file_io.get_matching_files(FLAGS.train_input_patches):
    # If one matching file is empty, go on to the next file.
    for record in tf_record.tf_record_iterator(filename):
      example = tf.train.Example.FromString(record)
      # Convert long (int64) to int, necessary for use in feature columns in
      # Python 2.
      patch_height = int(example.features.feature['height'].int64_list.value[0])
      patch_width = int(example.features.feature['width'].int64_list.value[0])
      return patch_height, patch_width
  # Previously this fell off the end and returned None, which surfaced as an
  # opaque "cannot unpack NoneType" error in callers. Fail loudly instead.
  raise ValueError(
      'No patch records found matching --train_input_patches: {}'.format(
          FLAGS.train_input_patches))
def input_fn(input_patches):
  """Defines the estimator input pipeline for labeled glyph patches.

  Args:
    input_patches: The input patches TFRecords pattern.

  Returns:
    A tuple of batched tensors (see `batches.get_batched_tensor`):
    * A dict of features with keys 'patch' (the augmented patch tensor) and
      'weight' (the per-example weight).
    * A tensor holding the batch of patch labels, as integers.
    Each evaluation of the tensors produces the next batch.
  """
  # NOTE: The previous docstring claimed this returns "a callable" with a
  # single-key feature dict; it actually returns batched tensors and the
  # feature dict also carries the weight column.
  patch_height, patch_width = read_patch_dimensions()
  dataset = tf.data.TFRecordDataset(file_io.get_matching_files(input_patches))

  def parser(record):
    """Dataset parser function.

    Args:
      record: A single serialized Example proto tensor.

    Returns:
      A tuple of:
      * A dict of features ('patch' and 'weight')
      * A label tensor (int64 scalar).
    """
    feature_types = {
        'patch': tf.FixedLenFeature((patch_height, patch_width), tf.float32),
        'label': tf.FixedLenFeature((), tf.int64),
    }
    if FLAGS.use_included_label_weight:
      feature_types['label_weight'] = tf.FixedLenFeature((), tf.float32)
    features = tf.parse_single_example(record, feature_types)
    label = features['label']
    weight = label_weights.weights_from_labels(label)
    if FLAGS.use_included_label_weight:
      # Both operands must be the same type (float32).
      weight = tf.to_float(weight) * tf.to_float(features['label_weight'])
    patch = _augment(features['patch'])
    return {'patch': patch, WEIGHT_COLUMN_NAME: weight}, label

  return batches.get_batched_tensor(dataset.map(parser))
def _augment(patch):
  """Applies all patch augmentations (shift, then rotation) in sequence."""
  shifted = _augment_shift(patch)
  return _augment_rotation(shifted)
def _augment_shift(patch):
  """Augments the patch by possibly shifting it 1 pixel horizontally."""
  with tf.name_scope('augment_shift'):
    # One uniform draw decides between left shift, right shift, or identity,
    # with left and right equally likely.
    rand = tf.random_uniform(())
    shift_prob = min(1., FLAGS.augmentation_x_shift_probability)
    return tf.cond(
        rand < shift_prob / 2,
        lambda: _shift_left(patch),
        lambda: tf.cond(rand < shift_prob,
                        lambda: _shift_right(patch),
                        lambda: patch))
def _shift_left(patch):
  """Shifts the patch one pixel left, duplicating the rightmost column."""
  patch = tf.convert_to_tensor(patch)
  tail, last_column = patch[:, 1:], patch[:, -1:]
  return tf.concat([tail, last_column], axis=1)
def _shift_right(patch):
  """Shifts the patch one pixel right, duplicating the leftmost column."""
  patch = tf.convert_to_tensor(patch)
  first_column, head = patch[:, :1], patch[:, :-1]
  return tf.concat([first_column, head], axis=1)
def _augment_rotation(patch):
  """Augments the patch by rotating it by a small random angle."""
  max_radians = math.radians(FLAGS.augmentation_max_rotation_degrees)
  angle = tf.random_uniform((), minval=-max_radians, maxval=max_radians)
  # Background is white (1.0) but tf.contrib.image.rotate currently always
  # fills the edges with black (0). Invert the patch before rotating, then
  # invert back.
  inverted = 1. - patch
  rotated = contrib_image.rotate(inverted, angle, interpolation='BILINEAR')
  return 1. - rotated
def serving_fn():
  """Returns the ServingInputReceiver for the exported model.

  Returns:
    A ServingInputReceiver object which may be passed to
    `Estimator.export_savedmodel`. A model saved using this receiver may be used
    for running OMR.
  """
  # Batch of serialized tf.train.Example protos, one per patch.
  examples = tf.placeholder(tf.string, shape=[None])
  patch_height, patch_width = read_patch_dimensions()
  parsed = tf.parse_example(examples, {
      'patch': tf.FixedLenFeature((patch_height, patch_width), tf.float32),
  })
  # NOTE(review): receiver_tensors is the parsed patch tensor rather than the
  # serialized `examples` placeholder; the serialized form is still reachable
  # via receiver_tensors_alternatives['example']. Confirm this is intentional.
  return tf.estimator.export.ServingInputReceiver(
      features={'patch': parsed['patch']},
      receiver_tensors=parsed['patch'],
      receiver_tensors_alternatives={
          'example': examples,
          'patch': parsed['patch']
      })
def create_patch_feature_column():
  """Returns the numeric 'patch' feature column, sized from the input data."""
  dimensions = read_patch_dimensions()
  return tf.feature_column.numeric_column('patch', shape=dimensions)
def multiclass_binary_metric(class_number, binary_metric, labels, predictions):
  """Wraps a binary metric for detecting a certain class (glyph type)."""
  # Reduce the multiclass problem to a binary one: an example is "positive"
  # iff its integer label (or prediction) equals the given class.
  is_label_positive = tf.equal(labels, class_number)
  is_predicted_positive = tf.equal(predictions['class_ids'], class_number)
  return binary_metric(
      labels=is_label_positive, predictions=is_predicted_positive)
def metrics_fn(features, labels, predictions):
  """Metrics to be computed on every evaluation run, viewable in TensorBoard.

  This function has the expected signature of a callable to be passed to
  `tf.contrib.estimator.add_metrics`.

  Args:
    features: Dict of feature tensors.
    labels: A tensor of example labels (ints).
    predictions: Dict of prediction types. Has an entry "class_ids" which is
      comparable to the ground truth in `labels`.

  Returns:
    A dict from metric name to TF metric.
  """
  del features  # Unused.
  num_classes = len(musicscore_pb2.Glyph.Type.keys())
  metrics = {
      'mean_per_class_accuracy':
          tf.metrics.mean_per_class_accuracy(
              labels=labels,
              predictions=predictions['class_ids'],
              num_classes=num_classes,
          ),
  }
  # Per-class precision and recall for a hand-picked list of glyph types.
  for name in FLAGS.classes_for_metrics:
    number = musicscore_pb2.Glyph.Type.Value(name)
    metrics['class/{}_precision'.format(name)] = multiclass_binary_metric(
        number, tf.metrics.precision, labels, predictions)
    metrics['class/{}_recall'.format(name)] = multiclass_binary_metric(
        number, tf.metrics.recall, labels, predictions)
  return metrics
def train_and_evaluate(estimator):
  """Runs the TF train-and-evaluate loop for the given estimator."""
  # Attach the custom evaluation metrics defined in `metrics_fn`.
  estimator_with_metrics = contrib_estimator.add_metrics(estimator, metrics_fn)
  train_spec = tf.estimator.TrainSpec(
      input_fn=lambda: input_fn(FLAGS.train_input_patches),
      max_steps=FLAGS.train_max_steps)
  exporter = tf.estimator.LatestExporter(
      'exporter', serving_fn, exports_to_keep=FLAGS.exports_to_keep)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=lambda: input_fn(FLAGS.eval_input_patches),
      start_delay_secs=0,
      throttle_secs=FLAGS.eval_throttle_secs,
      steps=FLAGS.eval_steps,
      exporters=[exporter])
  tf.estimator.train_and_evaluate(estimator_with_metrics, train_spec,
                                  eval_spec)
================================================
FILE: moonlight/models/base/glyph_patches_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for glyph_patches."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import flags
from moonlight.models.base import glyph_patches
import numpy as np
import tensorflow as tf
from tensorflow.python.lib.io import tf_record
class GlyphPatchesTest(tf.test.TestCase):
  """Tests for the glyph_patches input pipeline and metrics helpers."""

  def testInputFn(self):
    """input_fn parses written records into a batch of patches and labels."""
    with tempfile.NamedTemporaryFile() as records_file:
      with tf_record.TFRecordWriter(records_file.name) as records_writer:
        # Disable augmentation so the parsed patch is byte-identical to the
        # written one.
        flags.FLAGS.augmentation_x_shift_probability = 0
        flags.FLAGS.augmentation_max_rotation_degrees = 0
        example = tf.train.Example()
        height = 5
        width = 3
        example.features.feature['height'].int64_list.value.append(height)
        example.features.feature['width'].int64_list.value.append(width)
        example.features.feature['patch'].float_list.value.extend(
            range(height * width))
        label = 1
        example.features.feature['label'].int64_list.value.append(label)
        # Write the same example three times: one batch of three.
        for _ in range(3):
          records_writer.write(example.SerializeToString())
      flags.FLAGS.train_input_patches = records_file.name
      batch_tensors = glyph_patches.input_fn(records_file.name)
      with self.test_session() as sess:
        batch = sess.run(batch_tensors)
        self.assertAllEqual(
            batch[0]['patch'],
            np.arange(height * width).reshape(
                (1, height, width)).repeat(3, axis=0))
        self.assertAllEqual(batch[1], [label, label, label])

  def testShiftLeft(self):
    """_shift_left drops the first column and duplicates the last."""
    with self.test_session():
      self.assertAllEqual(
          # pyformat: disable
          glyph_patches._shift_left([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10,
                                                                  11]]).eval(),
          [[1, 2, 3, 3], [5, 6, 7, 7], [9, 10, 11, 11]])

  def testShiftRight(self):
    """_shift_right drops the last column and duplicates the first."""
    with self.test_session():
      self.assertAllEqual(
          # pyformat: disable
          glyph_patches._shift_right([[0, 1, 2, 3], [4, 5, 6, 7],
                                      [8, 9, 10, 11]]).eval(),
          [[0, 0, 1, 2], [4, 4, 5, 6], [8, 8, 9, 10]])

  def testMulticlassBinaryMetric(self):
    """Per-class precision/recall are computed from multiclass predictions."""
    # pyformat: disable
    # pylint: disable=bad-whitespace
    labels = tf.constant([1, 1, 3, 2, 2, 2, 2])
    predictions = dict(class_ids=tf.constant([1, 3, 2, 2, 4, 3, 2]))
    _, precision_1 = glyph_patches.multiclass_binary_metric(
        1, tf.metrics.precision, labels, predictions)
    _, recall_1 = glyph_patches.multiclass_binary_metric(
        1, tf.metrics.recall, labels, predictions)
    _, precision_2 = glyph_patches.multiclass_binary_metric(
        2, tf.metrics.precision, labels, predictions)
    _, recall_2 = glyph_patches.multiclass_binary_metric(
        2, tf.metrics.recall, labels, predictions)
    _, precision_3 = glyph_patches.multiclass_binary_metric(
        3, tf.metrics.precision, labels, predictions)
    _, recall_3 = glyph_patches.multiclass_binary_metric(
        3, tf.metrics.recall, labels, predictions)
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      # For class 1: 1 true positive and no false positives
      self.assertEqual(1.0, precision_1.eval())
      # For class 1: 1 true positive and 1 false negative
      self.assertEqual(0.5, recall_1.eval())
      # For class 2: 2 true positives and 1 false positive
      self.assertAlmostEqual(2 / 3, precision_2.eval(), places=5)
      # For class 2: 2 true positives and 2 false negatives
      self.assertEqual(0.5, recall_2.eval())
      # For class 3: No true positives
      self.assertEqual(0, precision_3.eval())
      self.assertEqual(0, recall_3.eval())
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/models/base/hyperparameters.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper which saves hyperparameters to a model collection.
The hyperparameters will be carefully tuned, and should be included in the
exported saved model to ensure reproducibility.
"""
# TODO(ringw): Try to get a standardized mechanism for saving params into
# TensorFlow.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
def estimator_with_saved_params(estimator, params):
  """Wraps an estimator with hyperparameters to be stored in the saved model.

  Args:
    estimator: A `tf.estimator.Estimator` instance.
    params: A dict of string to constant value (string, number, or NumPy
      array).

  Returns:
    A wrapped `tf.estimator.Estimator`. The model of the new estimator extends
    the wrapped model, and also includes the hyperparameters in a collection
    where they can be inspected later.

  Raises:
    ValueError: If a hyperparameter value is None.
  """
  # Validate parameters immediately, not in the model_fn.
  for name, value in six.iteritems(params):
    if value is None:
      raise ValueError('Hyperparameter cannot be None: {}'.format(name))

  # Estimator is mostly just a wrapper around the model_fn callable. Our
  # wrapper just needs a callable that adds all the params to a collection,
  # and then invokes the original callable.
  def model_fn(features, labels, mode, params, config):
    """Wraps the delegate estimator model_fn.

    Args:
      features: A dict of string to Tensor. Features to classify.
      labels: A Tensor with example labels, or None for prediction.
      mode: The mode string for the estimator.
      params: Passed through the newly constructed Estimator. These should be
        identical to the outer function's params.
      config: A TensorFlow estimator config object.

    Returns:
      An object holding the predictions, optimizer, etc.
    """
    # Each param becomes a named constant in the 'params' collection, so it is
    # written out with the exported graph (tensor names like 'params/<name>').
    with tf.name_scope('params'):
      for name, value in six.iteritems(params):
        tf.add_to_collection('params', tf.constant(name=name, value=value))
    # The Estimator model_fn property always returns a wrapped "public"
    # model_fn. The public wrapper doesn't take "params", and passes the params
    # from the Estimator constructor into the internal model_fn. Therefore, it
    # only matters that we pass the params to the Estimator below.
    return estimator.model_fn(features, labels, mode, config)

  return tf.estimator.Estimator(
      model_fn, model_dir=estimator.model_dir, params=params)
================================================
FILE: moonlight/models/base/hyperparameters_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.models.base import hyperparameters
import numpy as np
import tensorflow as tf
class HyperparametersTest(tf.test.TestCase):
  """Tests for hyperparameters.estimator_with_saved_params."""

  def testSimpleModel(self):
    """Saved params are retrievable from the graph by their tensor name."""
    learning_rate = np.float32(0.123)
    params = {'learning_rate': learning_rate}
    estimator = hyperparameters.estimator_with_saved_params(
        tf.estimator.DNNClassifier(
            hidden_units=[10],
            feature_columns=[tf.feature_column.numeric_column('feature')]),
        params)
    with self.test_session():
      # Build the estimator model.
      estimator.model_fn(
          features={'feature': tf.placeholder(tf.float32)},
          labels=tf.placeholder(tf.float32),
          mode='TRAIN',
          config=None)
      # We should be able to pull hyperparameters out of the TensorFlow graph.
      # The entire graph will also be written to the saved model in training.
      self.assertEqual(
          learning_rate,
          tf.get_default_graph().get_tensor_by_name(
              'params/learning_rate:0').eval())
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/models/base/label_weights.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configures the weight (importance) of examples with a given label.
Any glyph type may be over- or under-represented in the examples, which would
hurt the precision and/or recall for that glyph type. When training, the
gradient for each example is multiplied by the weight, which scales the
parameter update for that example.
For an example custom weight, if naturals are often misclassified as sharps, and
not vice versa, we may want to increase the weight for NATURAL.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from moonlight.protobuf import musicscore_pb2
import numpy as np
import six
import tensorflow as tf
FLAGS = flags.FLAGS

# Command-line override for per-label example weights; parsed by
# parse_label_weights_array below.
flags.DEFINE_string(
    'label_weights', 'NONE=0.5',
    'Example weights for patches of each label type. For example,'
    ' "NONE=0.01,FLAT=2.0" would weight "NONE" examples\' influence as 0.01,'
    ' "FLAT" examples as 2.0, and all other examples as 1.0.')

# The glyph types array must be large enough to hold the highest enum value.
GLYPH_TYPES_ARRAY_SIZE = 1 + max(
    number for name, number in musicscore_pb2.Glyph.Type.items())
def parse_label_weights_array(weights_str=None):
  """Creates an array with all of the label weights.

  Args:
    weights_str: String of label name-weight pairs, separated by commas.
      Defaults to the command-line flag.

  Returns:
    A NumPy array large enough to hold all of the glyph enum types. At the
    index for a glyph enum value, we store the example weight, defaulting to
    1.0.

  Raises:
    ValueError: If an entry is not of the form NAME=WEIGHT, if a glyph type is
      listed multiple times, or if a weight is not a valid float.
  """
  weights_str = weights_str or FLAGS.label_weights
  weights_array = np.ones(GLYPH_TYPES_ARRAY_SIZE)
  if not weights_str:
    return weights_array
  weights = {}
  for pair in weights_str.split(','):
    # Previously a bare `name, w = pair.split('=')` raised an opaque
    # "not enough values to unpack" error on a malformed entry.
    name, sep, glyph_weight_str = pair.partition('=')
    if not sep:
      raise ValueError(
          'Expected NAME=WEIGHT in label_weights, got: {}'.format(pair))
    if name in weights:
      raise ValueError('Duplicate weight: {}'.format(name))
    weights[name] = float(glyph_weight_str)
  for name, weight in six.iteritems(weights):
    weights_array[musicscore_pb2.Glyph.Type.Value(name)] = weight
  return weights_array
def weights_from_labels(labels, weights_str=None):
  """Determines the example weights from a tensor of example labels."""
  with tf.name_scope('weights_from_labels'):
    # Look each label up in the per-glyph-type weight table.
    weight_table = tf.constant(
        parse_label_weights_array(weights_str), name='label_weights')
    return tf.gather(weight_table, labels, name='label_weights_lookup')
================================================
FILE: moonlight/models/base/label_weights_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for label_weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.models.base import label_weights
from moonlight.protobuf import musicscore_pb2
import tensorflow as tf
class LabelWeightsTest(tf.test.TestCase):
  """Tests for label_weights.weights_from_labels."""

  def testWeightsFromLabels(self):
    """Labels map to their configured weights; unlisted types default to 1."""
    g = musicscore_pb2.Glyph
    labels = tf.constant(
        [g.NONE, g.NONE, g.NOTEHEAD_FILLED, g.SHARP, g.FLAT, g.NATURAL])
    weights = 'NONE=0.1,NATURAL=2.0,SHARP=0.5,NOTEHEAD_FILLED=0.8'
    weights_tensor = label_weights.weights_from_labels(labels, weights)
    with self.test_session():
      # FLAT is not listed above, so it gets the default weight 1.0.
      self.assertAllEqual([0.1, 0.1, 0.8, 0.5, 1.0, 2.0], weights_tensor.eval())
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/models/glyphs_dnn/BUILD
================================================
# Description:
# Glyph patches DNN classifier.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "model",
srcs = ["model.py"],
srcs_version = "PY2AND3",
deps = [
# absl dep
"//moonlight/models/base:glyph_patches",
"//moonlight/models/base:hyperparameters",
"//moonlight/protobuf:protobuf_py_pb2",
],
)
py_binary(
name = "train",
srcs = ["train.py"],
srcs_version = "PY2AND3",
deps = [
":model",
# disable_tf2
# absl dep
"//moonlight/models/base:glyph_patches",
],
)
================================================
FILE: moonlight/models/glyphs_dnn/model.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the glyph patches DNN classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from moonlight.models.base import glyph_patches
from moonlight.models.base import hyperparameters
from moonlight.protobuf import musicscore_pb2
import tensorflow as tf
FLAGS = flags.FLAGS

# Command-line hyperparameters for the DNN classifier. These are gathered into
# a params dict by get_flag_params() and saved with the model.
flags.DEFINE_multi_integer(
    'layer_dims', [20, 20],
    'Dimensions of each hidden layer. --layer_dims=0 indicates logistic'
    ' regression (predictions directly connected to inputs through a sigmoid'
    ' layer).')
flags.DEFINE_string(
    'activation_fn', 'sigmoid',
    'The name of the function (under tf.nn) to apply after each layer.')
flags.DEFINE_float('learning_rate', 0.1, 'FTRL learning rate')
flags.DEFINE_float('l1_regularization_strength', 0.01, 'L1 penalty')
flags.DEFINE_float('l2_regularization_strength', 0, 'L2 penalty')
flags.DEFINE_float('dropout', 0, 'Dropout to apply to all hidden nodes.')
def get_flag_params():
  """Returns the hyperparameters specified by flags.

  Returns:
    A dict of hyperparameter names and values.
  """
  hidden_dims = FLAGS.layer_dims
  if not any(hidden_dims):
    # A single layer of size 0 on the command line indicates logistic
    # regression (no hidden dims at all).
    hidden_dims = []
  return {
      'model_name': 'glyphs_dnn',
      'layer_dims': hidden_dims,
      'activation_fn': FLAGS.activation_fn,
      'learning_rate': FLAGS.learning_rate,
      'l1_regularization_strength': FLAGS.l1_regularization_strength,
      'l2_regularization_strength': FLAGS.l2_regularization_strength,
      'dropout': FLAGS.dropout,
      # Declared in glyph_patches.py.
      'augmentation_x_shift_probability':
          FLAGS.augmentation_x_shift_probability,
      'augmentation_max_rotation_degrees':
          FLAGS.augmentation_max_rotation_degrees,
      'use_included_label_weight': FLAGS.use_included_label_weight,
      # Declared in label_weights.py.
      'label_weights': FLAGS.label_weights,
  }
def create_estimator(params=None):
  """Returns the glyphs DNNClassifier estimator.

  Args:
    params: Optional hyperparameters, defaulting to command-line values.

  Returns:
    A DNNClassifier instance.
  """
  params = params or get_flag_params()
  if not params['layer_dims'] and params['activation_fn'] != 'sigmoid':
    tf.logging.warning(
        'activation_fn should be sigmoid for logistic regression. Got: %s',
        params['activation_fn'])
  # Resolve the activation function by name under tf.nn (e.g. 'relu').
  activation_fn = getattr(tf.nn, params['activation_fn'])
  estimator = tf.estimator.DNNClassifier(
      params['layer_dims'],
      feature_columns=[glyph_patches.create_patch_feature_column()],
      weight_column=glyph_patches.WEIGHT_COLUMN_NAME,
      n_classes=len(musicscore_pb2.Glyph.Type.keys()),
      optimizer=tf.train.FtrlOptimizer(
          learning_rate=params['learning_rate'],
          l1_regularization_strength=params['l1_regularization_strength'],
          l2_regularization_strength=params['l2_regularization_strength'],
      ),
      activation_fn=activation_fn,
      # Bug fix: read dropout from params, like every other hyperparameter.
      # Previously this read FLAGS.dropout directly, silently ignoring an
      # explicitly-passed params['dropout'].
      dropout=params['dropout'],
      model_dir=glyph_patches.FLAGS.model_dir,
  )
  # Persist the hyperparameters alongside the trained model.
  return hyperparameters.estimator_with_saved_params(estimator, params)
================================================
FILE: moonlight/models/glyphs_dnn/train.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training the glyphs DNN classifier."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
from moonlight.models.base import glyph_patches
from moonlight.models.glyphs_dnn import model
import tensorflow as tf
FLAGS = flags.FLAGS
def main(_):
  """Entry point: trains and evaluates the glyphs DNN classifier."""
  tf.logging.set_verbosity(tf.logging.INFO)
  estimator = model.create_estimator()
  glyph_patches.train_and_evaluate(estimator)
if __name__ == '__main__':
app.run(main)
================================================
FILE: moonlight/music/BUILD
================================================
# Description:
# General music theory for OMR.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "constants",
srcs = ["constants.py"],
srcs_version = "PY2AND3",
)
================================================
FILE: moonlight/music/constants.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants for music theory in OMR."""
# The indices of the pitch classes in a major scale.
# These are the semitone offsets of the scale degrees from the tonic.
MAJOR_SCALE = [0, 2, 4, 5, 7, 9, 11]
NUM_SEMITONES_PER_OCTAVE = 12
# These constants are coincidentally equal.
# The size of the perfect fifth interval.
NUM_SEMITONES_IN_PERFECT_FIFTH = 7
# The number of pitch classes present in a diatonic scale (e.g. the major scale)
NUM_NOTES_IN_DIATONIC_SCALE = 7
# The consecutive base notes of a key signature are each separated by a fifth,
# or 7 semitones.
CIRCLE_OF_FIFTHS = [0, 7, 2, 9, 4, 11, 6, 1, 8, 3, 10, 5]
# CIRCLE_OF_FIFTHS is declared as a constant for clarity, but can be generated
# from:
# CIRCLE_OF_FIFTHS = [
#     (i * NUM_SEMITONES_IN_PERFECT_FIFTH) % NUM_SEMITONES_PER_OCTAVE
#     for i in range(NUM_SEMITONES_PER_OCTAVE)
# ]
================================================
FILE: moonlight/omr.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs OMR and outputs a Score or NoteSequence message."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import time
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from moonlight import conversions
from moonlight import engine
from moonlight.glyphs import saved_classifier_fn
FLAGS = flags.FLAGS

VALID_OUTPUT_TYPES = ['MusicXML', 'NoteSequence', 'Score']

# The name of the png file contents tensor.
PNG_CONTENTS_TENSOR = 'png_contents'

flags.DEFINE_string(
    'glyphs_saved_model', None,
    'Path to the patch-based glyph classifier saved model dir. Defaults to the'
    ' included KNN classifier.')
flags.DEFINE_string('output', '/dev/stdout',
                    'Path to write the output text-format or binary proto.')
# Bug fix: the help text previously listed only "(Score or NoteSequence)",
# omitting MusicXML even though it is in VALID_OUTPUT_TYPES and handled in
# main().
flags.DEFINE_string(
    'output_type', 'Score',
    'Which output type to produce (MusicXML, NoteSequence, or Score).')
flags.DEFINE_boolean('text_format', True, 'Whether the output is text format.')
def run(input_pngs, glyphs_saved_model=None, output_notesequence=False):
  """Runs OMR over a list of input images.

  Args:
    input_pngs: A list of PNG filenames to process.
    glyphs_saved_model: Optional saved model dir to override the included model.
    output_notesequence: Whether to return a NoteSequence, as opposed to a Score
        containing Pages with Glyphs.

  Returns:
    A NoteSequence message, or a Score message holding Pages for each input
    image (with their detected Glyphs).
  """
  classifier_fn = saved_classifier_fn.build_classifier_fn(glyphs_saved_model)
  omr_engine = engine.OMREngine(classifier_fn)
  return omr_engine.run(input_pngs, output_notesequence=output_notesequence)
def main(argv):
  """Runs OMR over the PNG files matching the command-line globs."""
  if FLAGS.output_type not in VALID_OUTPUT_TYPES:
    raise ValueError('output_type "%s" not in allowed types: %s' %
                     (FLAGS.output_type, VALID_OUTPUT_TYPES))

  # Exclude argv[0], which is the current binary.
  globs = argv[1:]
  if not globs:
    raise ValueError('PNG file glob(s) must be specified')

  # Expand each glob, requiring that every pattern matches something.
  png_paths = []
  for glob in globs:
    matched = file_io.get_matching_files(glob)
    if not matched:
      raise ValueError('Pattern "%s" failed to match any files' % glob)
    png_paths.extend(matched)

  start = time.time()
  output = run(
      png_paths,
      FLAGS.glyphs_saved_model,
      output_notesequence=FLAGS.output_type == 'NoteSequence')
  end = time.time()
  sys.stderr.write('OMR elapsed time: %.2f\n' % (end - start))

  # Serialize according to the requested output type and format.
  if FLAGS.output_type == 'MusicXML':
    output_bytes = conversions.score_to_musicxml(output)
  elif FLAGS.text_format:
    output_bytes = text_format.MessageToString(output).encode('utf-8')
  else:
    output_bytes = output.SerializeToString()
  file_io.write_string_to_file(FLAGS.output, output_bytes)
if __name__ == '__main__':
tf.app.run()
================================================
FILE: moonlight/omr_endtoend_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple end-to-end test for OMR on the sample image."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from absl.testing import absltest
import librosa
from lxml import etree
import numpy as np
from PIL import Image
from protobuf import music_pb2
from tensorflow.python.platform import resource_loader
from moonlight import conversions
from moonlight import engine
class OmrEndToEndTest(absltest.TestCase):
  """End-to-end OMR tests against the checked-in sample IMSLP page image."""

  def setUp(self):
    # Engine built with no saved-model override (uses the default classifier).
    self.engine = engine.OMREngine()

  def testNoteSequence(self):
    filename = os.path.join(resource_loader.get_data_files_path(),
                            'testdata/IMSLP00747-000.png')
    # NOTE(review): run() is given a bare filename here but a list in
    # testBeams_sixteenthNotes below -- presumably both are accepted; verify.
    notes = self.engine.run(filename, output_notesequence=True)
    # TODO(ringw): Fix the extra note that is detected before the actual
    # first eighth note.
    self.assertEqual(librosa.note_to_midi('C4'), notes.notes[1].pitch)
    self.assertEqual(librosa.note_to_midi('D4'), notes.notes[2].pitch)
    self.assertEqual(librosa.note_to_midi('E4'), notes.notes[3].pitch)
    self.assertEqual(librosa.note_to_midi('F4'), notes.notes[4].pitch)
    self.assertEqual(librosa.note_to_midi('D4'), notes.notes[5].pitch)
    self.assertEqual(librosa.note_to_midi('E4'), notes.notes[6].pitch)
    self.assertEqual(librosa.note_to_midi('C4'), notes.notes[7].pitch)

  def testBeams_sixteenthNotes(self):
    filename = os.path.join(resource_loader.get_data_files_path(),
                            'testdata/IMSLP00747-000.png')
    notes = self.engine.run([filename], output_notesequence=True)

    def _sixteenth_note(pitch, start_time):
      # A sixteenth note lasts 0.25s here (end_time = start_time + 0.25).
      return music_pb2.NoteSequence.Note(
          pitch=librosa.note_to_midi(pitch),
          start_time=start_time,
          end_time=start_time + 0.25)

    # TODO(ringw): Fix the phantom quarter note detected before the treble
    # clef, and the eighth rest before the first note (should be sixteenth).
    self.assertIn(_sixteenth_note('C4', 1.5), notes.notes)
    self.assertIn(_sixteenth_note('D4', 1.75), notes.notes)
    self.assertIn(_sixteenth_note('E4', 2), notes.notes)
    self.assertIn(_sixteenth_note('F4', 2.25), notes.notes)
    # TODO(ringw): The second D and E are detected with only one beam, even
    # though they are connected to the same beams as the F before them and the
    # C after them. Fix.

  def testIMSLP00747_000_structure_barlines(self):
    # Checks the detected staff system / staff / barline counts per system.
    page = self.engine.run(
        os.path.join(resource_loader.get_data_files_path(),
                     'testdata/IMSLP00747-000.png')).page[0]
    self.assertEqual(len(page.system), 6)
    self.assertEqual(len(page.system[0].staff), 2)
    self.assertEqual(len(page.system[0].bar), 4)
    self.assertEqual(len(page.system[1].staff), 2)
    self.assertEqual(len(page.system[1].bar), 5)
    self.assertEqual(len(page.system[2].staff), 2)
    self.assertEqual(len(page.system[2].bar), 5)
    self.assertEqual(len(page.system[3].staff), 2)
    self.assertEqual(len(page.system[3].bar), 4)
    self.assertEqual(len(page.system[4].staff), 2)
    self.assertEqual(len(page.system[4].bar), 5)
    self.assertEqual(len(page.system[5].staff), 2)
    self.assertEqual(len(page.system[5].bar), 5)
    for system in page.system:
      for staff in system.staff:
        self.assertEqual(staff.staffline_distance, 16)

  def testMusicXML(self):
    filename = os.path.join(resource_loader.get_data_files_path(),
                            'testdata/IMSLP00747-000.png')
    score = self.engine.run([filename])
    # n bars bound n - 1 measures (see StaffSystem in musicscore.proto).
    num_measures = sum(
        len(system.bar) - 1 for page in score.page for system in page.system)
    musicxml = etree.fromstring(conversions.score_to_musicxml(score))
    self.assertEqual(2, len(musicxml.findall('part')))
    self.assertEqual(num_measures,
                     len(musicxml.find('part[1]').findall('measure')))

  def testProcessImage(self):
    # Exercises the raw-ndarray entry point instead of the filename one.
    pil_image = Image.open(
        os.path.join(resource_loader.get_data_files_path(),
                     'testdata/IMSLP00747-000.png')).convert('L')
    arr = np.array(pil_image.getdata(), np.uint8).reshape(
        # Size is (width, height).
        pil_image.size[1],
        pil_image.size[0])
    page = self.engine.process_image(arr)
    self.assertEqual(6, len(page.system))
if __name__ == '__main__':
absltest.main()
================================================
FILE: moonlight/omr_regression_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests OMR with corpus scores from Placer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import re
from absl.app import flags
from absl.testing import absltest
from moonlight import engine
from moonlight.protobuf import musicscore_pb2
from moonlight.score import measures
from moonlight.score import reader
NOTEHEAD_FILLED = musicscore_pb2.Glyph.NOTEHEAD_FILLED
IMSLP_FILENAME = re.compile('IMSLP([0-9]{5,})-[0-9]{3}.png')
flags.DEFINE_string('corpus_dir', 'corpus', 'Path to the extracted IMSLP pngs.')
FLAGS = flags.FLAGS
class OmrRegressionTest(absltest.TestCase):
  """Regression tests running OMR on IMSLP pages under --corpus_dir."""

  def testIMSLP01963_106_multipleStaffSizes(self):
    # Two different staffline distances (14 and 22) on the same page.
    page = engine.OMREngine().run(_get_imslp_path('IMSLP01963-106.png')).page[0]
    self.assertEqual(len(page.system), 3)
    for system in page.system:
      self.assertEqual(len(system.staff), 4)
      self.assertEqual(system.staff[0].staffline_distance, 14)
      self.assertEqual(system.staff[1].staffline_distance, 14)
      self.assertEqual(system.staff[2].staffline_distance, 22)
      self.assertEqual(system.staff[3].staffline_distance, 22)

  def testIMSLP00823_000_structure(self):
    page = engine.OMREngine().run(_get_imslp_path('IMSLP00823-000.png')).page[0]
    self.assertEqual(len(page.system), 6)
    self.assertEqual(len(page.system[0].staff), 2)
    self.assertEqual(len(page.system[0].bar), 7)
    self.assertEqual(len(page.system[1].staff), 2)
    self.assertEqual(len(page.system[1].bar), 7)
    self.assertEqual(len(page.system[2].staff), 2)
    self.assertEqual(len(page.system[2].bar), 6)
    self.assertEqual(len(page.system[3].staff), 2)
    self.assertEqual(len(page.system[3].bar), 6)
    self.assertEqual(len(page.system[4].staff), 2)
    # TODO(ringw): Fix barline detection here.
    # self.assertEqual(len(page.system[4].bar), 6)
    self.assertEqual(len(page.system[5].staff), 2)
    self.assertEqual(len(page.system[5].bar), 6)

  def testIMSLP00823_008_mergeStandardAndBeginRepeatBars(self):
    page = engine.OMREngine().run(_get_imslp_path('IMSLP00823-008.png')).page[0]
    self.assertEqual(len(page.system), 6)
    self.assertEqual(len(page.system[0].staff), 2)
    # TODO(ringw): Fix barline detection here.
    # self.assertEqual(len(page.system[0].bar), 6)
    self.assertEqual(len(page.system[1].staff), 2)
    # self.assertEqual(len(page.system[1].bar), 6)
    self.assertEqual(len(page.system[2].staff), 2)
    self.assertEqual(len(page.system[2].bar), 7)
    self.assertEqual(len(page.system[3].staff), 2)
    self.assertEqual(len(page.system[3].bar), 6)
    self.assertEqual(len(page.system[4].staff), 2)
    self.assertEqual(len(page.system[4].bar), 6)
    # TODO(ringw): Detect BEGIN_REPEAT_BAR here.
    self.assertEqual(page.system[4].bar[0].type,
                     musicscore_pb2.StaffSystem.Bar.END_BAR)
    self.assertEqual(page.system[4].bar[1].type,
                     musicscore_pb2.StaffSystem.Bar.STANDARD_BAR)
    self.assertEqual(len(page.system[5].staff), 2)
    self.assertEqual(len(page.system[5].bar), 7)

  def testIMSLP39661_keySignature_CSharpMinor(self):
    page = engine.OMREngine().run(_get_imslp_path('IMSLP39661-000.png')).page[0]
    score_reader = reader.ScoreReader()
    # One of the sharps in the first system is heavily obscured.
    score_reader.read_system(page.system[1])
    treble_sig = score_reader.score_state.staves[0].get_key_signature()
    self.assertEqual(treble_sig.get_type(), musicscore_pb2.Glyph.SHARP)
    self.assertEqual(len(treble_sig), 4)
    bass_sig = score_reader.score_state.staves[1].get_key_signature()
    self.assertEqual(bass_sig.get_type(), musicscore_pb2.Glyph.SHARP)
    # TODO(ringw): Get glyphs detected correctly in the bass signature.
    # self.assertEqual(len(bass_sig), 4)

  def testIMSLP00023_015_doubleNoteDots(self):
    """Tests note dots in system[1].staff[1] of the image."""
    page = engine.OMREngine().run(_get_imslp_path('IMSLP00023-015.png')).page[0]
    self.assertEqual(len(page.system), 6)
    system = page.system[1]
    system_measures = measures.Measures(system)
    staff = system.staff[1]
    # All dotted notes in the first measure belong to one chord, and are
    # double-dotted.
    double_dotted_notes = [
        glyph for glyph in staff.glyph
        if system_measures.get_measure(glyph) == 0 and len(glyph.dot) == 2
    ]
    for note in double_dotted_notes:
      self.assertEqual(len(note.beam), 1)
      # Double-dotted eighth note duration.
      self.assertEqual(note.note.end_time - note.note.start_time,
                       .5 + .25 + .125)
    double_dotted_note_ys = [glyph.y_position for glyph in double_dotted_notes]
    self.assertIn(-6, double_dotted_note_ys)
    self.assertIn(-3, double_dotted_note_ys)
    self.assertIn(-1, double_dotted_note_ys)
    self.assertTrue(
        set(double_dotted_note_ys).issubset([-6, -3, -1, +3, +4]),
        'No unexpected double-dotted noteheads')
    # TODO(ringw): Notehead at +4 picks up extra dots (4 total). The dots
    # should be in a horizontal line, and we should discard other dots.
    # There should only be one notehead at +4 with 2 or more dots.
    self.assertEqual(
        len([
            glyph for glyph in staff.glyph
            if system_measures.get_measure(glyph) == 0 and glyph.type ==
            NOTEHEAD_FILLED and glyph.y_position == +4 and len(glyph.dot) >= 2
        ]), 1)
    # All dotted notes in the second measure belong to one chord, and are
    # single-dotted.
    single_dotted_notes = [
        glyph for glyph in staff.glyph
        if system_measures.get_measure(glyph) == 1 and len(glyph.dot) == 1
    ]
    for note in single_dotted_notes:
      if note.y_position == +2:
        # TODO(ringw): Detect the beam for this notehead. Its stem is too
        # short.
        continue
      self.assertEqual(len(note.beam), 1)
      # Single-dotted eighth note duration.
      self.assertEqual(note.note.end_time - note.note.start_time, .75)
    single_dotted_note_ys = [glyph.y_position for glyph in single_dotted_notes]
    self.assertIn(-5, single_dotted_note_ys)
    self.assertIn(-3, single_dotted_note_ys)
    self.assertIn(0, single_dotted_note_ys)
    self.assertIn(+2, single_dotted_note_ys)
    # TODO(ringw): Detect the dot for the note at y position +4.
    self.assertTrue(set(single_dotted_note_ys).issubset([-5, -3, 0, +2, +4]))
def _get_imslp_path(filename):
  """Resolves an IMSLP page image basename to its path under --corpus_dir.

  Images are sharded into numbered subdirectories, one bucket per 1000 IMSLP
  ids (e.g. IMSLP01963-106.png lives in <corpus_dir>/1/).

  Args:
    filename: An image basename matching IMSLP_FILENAME.

  Returns:
    The full path of the image in the extracted corpus.

  Raises:
    ValueError: If filename is not a valid IMSLP page image name.
  """
  m = IMSLP_FILENAME.match(filename)
  if m is None:
    # Robustness fix: previously a non-matching name raised an opaque
    # AttributeError on m.group(1).
    raise ValueError('Not an IMSLP page image filename: %s' % filename)
  bucket = int(m.group(1)) // 1000
  return os.path.join(FLAGS.corpus_dir, str(bucket), filename)
if __name__ == '__main__':
absltest.main()
================================================
FILE: moonlight/page_processors.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Callables that process Page messages.
Each processor takes a Page and returns a possibly modified Page. It may modify
the Page message in place, and return the same message.
The purpose of a processor is to perform simple inference on elements already in
the Page and in the Structure. Processing should not be CPU-intensive, or the
heavy lifting needs to be implemented in TensorFlow for efficiency.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.glyphs import glyph_types
from moonlight.glyphs import note_dots
from moonlight.glyphs import repeated
from moonlight.staves import staff_processor
from moonlight.structure import barlines
from moonlight.structure import beam_processor
from moonlight.structure import section_barlines
from moonlight.structure import stems
def create_processors(structure, staffline_extractor=None):
  """Generator for the processors to be applied to the Page in order.

  Processors are applied in exactly the order yielded here (see process()),
  and each is constructed lazily as the generator is advanced.

  Args:
    structure: The computed `Structure`.
    staffline_extractor: The staffline extractor to use for scaling glyph x
      coordinates. Optional.

  Yields:
    Callables which accept a single `Page` as an argument, and return it
    (either modifying in place or returning a modified copy).
  """
  yield staff_processor.StaffProcessor(structure, staffline_extractor)
  yield stems.Stems(structure)
  yield beam_processor.BeamProcessor(structure)
  yield note_dots.NoteDots(structure)
  yield CenteredRests()
  yield repeated.FixRepeatedRests()
  yield barlines.Barlines(structure)
  yield section_barlines.SectionBarlines(structure)
  yield section_barlines.MergeStandardAndBeginRepeatBars(structure)
def process(page, structure, staffline_extractor=None):
  """Applies all of the page processors to the Page, in order."""
  result = page
  for page_processor in create_processors(structure, staffline_extractor):
    result = page_processor.apply(result)
  return result
# TODO(ringw): Add a helper for processors that filter the glyphs like this.
class CenteredRests(object):
  """Discards rests that are too far from the vertical center of the staff."""

  def apply(self, page):
    """Rests should be centered on the staff, assuming a single voice."""
    for system in page.system:
      for staff in system.staff:
        # Collect first, then remove, to avoid mutating while iterating.
        off_center = [
            glyph for glyph in staff.glyph
            if glyph_types.is_rest(glyph) and abs(glyph.y_position) > 2
        ]
        for glyph in off_center:
          staff.glyph.remove(glyph)
    return page
================================================
FILE: moonlight/pipeline/BUILD
================================================
# Description:
# OMR pipelines and Apache Beam utilities.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "pipeline_flags",
srcs = ["pipeline_flags.py"],
srcs_version = "PY2AND3",
deps = [
# absl dep
# apache-beam dep
],
)
================================================
FILE: moonlight/pipeline/pipeline_flags.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configures the Apache Beam runner from the command line in pipelines.
Command-line flags for particular runners can be added here later, if necessary.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import apache_beam
FLAGS = flags.FLAGS
flags.DEFINE_string(
'runner', 'DirectRunner',
'The class name of the Apache Beam runner to use in the pipeline.')
def create_pipeline(**kwargs):
  """Creates a Beam pipeline using the runner named by the --runner flag."""
  return apache_beam.Pipeline(FLAGS.runner, **kwargs)
================================================
FILE: moonlight/protobuf/BUILD
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
py_proto_library(
name = "protobuf_py_pb2",
srcs = glob(["*.proto"]),
deps = ["@magenta//protobuf:music_py_pb2"],
default_runtime = "@com_google_protobuf//:protobuf_python",
protoc = "@com_google_protobuf//:protoc",
srcs_version = "PY2AND3",
)
================================================
FILE: moonlight/protobuf/groundtruth.proto
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package tensorflow.moonlight;
// MusicXML ground truth corpus. Each MusicXML file corresponds to one or more
// music score page images. The output of OMR on the pages is compared against
// the MusicXML ground truth to evaluate the OMR engine.
message GroundTruth {
optional string title = 1;
optional string ground_truth_filename = 2;
// Has one or more page images corresponding to the ground truth.
repeated PageSpec page_spec = 3;
enum Tag {
// A parsed tag that was unrecognized.
UNKNOWN_TAG = 0;
// One voice per staff, with chords allowed in each voice. If not present,
// all staves are assumed to be polyphonic, which is not supported, and the
// score may be excluded from evaluation.
MONOPHONIC = 1;
// One staff per staff system.
SINGLE_STAFF = 2;
// Piano/grand staff.
PIANO = 3;
// One voice (usually connected by a beam) has notes in multiple staves.
// This is not supported by evaluation, and the score may be excluded from
// evaluation.
VOICES_CROSS_STAVES = 4;
// The score contains chords in more than one measure (chords are often
// present in the last measure, and we want to select scores with more than
// one measure).
CHORDS = 5;
}
repeated Tag tag = 4;
}
message PageSpec {
optional string filename = 1;
// TODO(ringw): Allow choosing a range of staff systems within one page.
// The score may have one movement end and the next movement start on the same
// page. The MusicXML ground truth will likely have one file per movement.
}
================================================
FILE: moonlight/protobuf/musicscore.proto
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package tensorflow.moonlight;
// TODO(ringw): Fix python import issue, then load magenta with a prefix of
// "magenta" instead of "magenta/magenta". The import here should be
// "magenta/music/protobuf/music.proto".
import "protobuf/music.proto";
// Alpha-stage OMR page model. Most protos (except Glyph.Type) are subject to
// backwards-incompatible changes (e.g. renumbering or replacing fields).
// A score, containing multiple pages.
message Score {
repeated Page page = 1;
}
// A page, holding all detected objects.
message Page {
repeated StaffSystem system = 1;
}
// A staff system, holding one or more staves connected by barline(s).
// Contains zero or more bars, x coordinates where there is a vertical barline.
// If there are no bars, the staff contains a single measure spanning the width
// of the page. If there is a single bar, the staff contains a single measure
// spanning [bar[0], width). Otherwise, for n bars, there are n - 1 measures
// bounded by each adjacent pair of bars.
message StaffSystem {
  repeated Staff staff = 1;
  // Barlines; assumed ordered by x so that adjacent pairs bound measures
  // (see the measure semantics described above).
  repeated Bar bar = 2;

  // A vertical barline at a given x coordinate.
  message Bar {
    optional int32 x = 1;
    optional Type type = 2;

    enum Type {
      UNKNOWN_BAR = 0;
      STANDARD_BAR = 1;
      DOUBLE_BAR = 2;
      END_BAR = 3;
      BEGIN_REPEAT_BAR = 4;
      END_REPEAT_BAR = 5;
      BEGIN_AND_END_REPEAT_BAR = 6;
    }
  }
}
// A staff is represented by the distance between consecutive horizontal lines,
// and line segments that make up the center line (3rd out of 5 lines).
message Staff {
  // The vertical distance between the horizontal lines that make up the staff.
  // This is assumed to be constant for printed music scores.
  optional int64 staffline_distance = 1;
  // The approximately horizontal staff center line (third line of the staff).
  // The other lines must be a multiple of staffline_distance above and below.
  // Currently, the line always starts at x = 0 and ends at x = width - 1.
  repeated Point center_line = 2;
  // Glyphs which belong to the staff, sorted by x coordinate.
  repeated Glyph glyph = 3;
}
// A musical element which is attached to a staff. The y_position represents the
// the staffline or staff space that a glyph is placed on. The center line
// (third staff line) has position 0, and the y_position increases with
// decreasing y coordinates (towards musically higher notes).
// For example, in treble clef, position 0 is B4, position -6 is C4, and
// position 1 is C5. For more details, see:
// g3doc/learning/brain/research/magenta/models/omr/g3doc/glyphs.md
message Glyph {
  optional Type type = 1;
  // Horizontal pixel coordinate of the glyph on the page.
  optional int64 x = 2;
  // Vertical position in staff steps (see the message comment above).
  optional int64 y_position = 3;
  // Noteheads may be attached to a Stem. The reader joins noteheads with the
  // same Stem as chords, which makes them all have the same timing.
  optional LineSegment stem = 4;
  // The Stem may intersect with Beam(s). A NOTEHEAD_FILLED has a duration of 1
  // which is halved for each incident Beam.
  repeated LineSegment beam = 5;
  // Dots are attached to noteheads and alter their duration. Each dot adds half
  // of the previous dot's value (or half the original duration, for the first
  // dot) to the final duration.
  repeated Point dot = 6;
  // The Glyph may be detected as a Note by ScoreReader. The Note is stored as
  // part of the Glyph.
  optional magenta.NoteSequence.Note note = 7;

  // Glyph Types are used as labels for the Examples in our corpus. Don't change
  // or reuse the numbers!
  enum Type {
    // Default value.
    UNKNOWN_TYPE = 0;
    // No glyph at the given position. This is used for classification, where
    // we will evaluate every pixel as a possible center point for a glyph.
    NONE = 1;
    // G clef. Typically centered on the second line from the bottom of the
    // staff (G4 in treble clef). May be shifted in other, uncommon clefs.
    CLEF_TREBLE = 2;
    // F clef. Typically centered on the fourth line from the bottom of the
    // staff (F3 in bass clef). May be shifted in other, uncommon clefs.
    CLEF_BASS = 3;
    // C clef is the clef symbol used for alto, tenor, and other clefs. In each
    // such clef, the line that the C clef is centered on represents C4.
    CLEF_C = 13;
    // "Common time" is a stylized "c" which is equivalent to a 4/4 time
    // signature. Note that numeral time signatures are not detected here. They
    // will need to be detected in a post-processing step using OCR.
    TIME_COMMON = 18;
    // "Cut time" has a vertical line through the "common time" symbol, and is
    // equivalent to 2/2 time.
    TIME_CUT = 19;
    NOTEHEAD_FILLED = 4;
    NOTEHEAD_EMPTY = 5;
    NOTEHEAD_WHOLE = 6;
    // Note that whole and half rests will likely not be labeled for the
    // classifier. They can easily be detected as rectangular connected
    // components after staff removal, and are distinguished by their absolute
    // position on the staff. However, the two currently look identical on the
    // extracted patches because we use staff removal.
    REST_WHOLE = 14;
    REST_HALF = 15;
    REST_QUARTER = 7;
    REST_EIGHTH = 8;
    REST_SIXTEENTH = 9;
    // TODO(ringw): Sixty-fourth rests don't fit within our typical patch
    // size (3 times the staffline distance), so they can't be detected under
    // the current model.
    REST_THIRTYSECOND = 17;
    FLAT = 10;
    SHARP = 11;
    DOUBLE_SHARP = 16;
    NATURAL = 12;
    // Next tag: 20
  }
}
// A line segment between two points (used for note stems and beams).
message LineSegment {
  optional Point start = 1;
  optional Point end = 2;
}

// An axis-aligned rectangle defined by two corner points.
message Rect {
  optional Point top_left = 1;
  optional Point bottom_right = 2;
}

// A 2D point in page pixel coordinates.
message Point {
  optional int64 x = 1;
  optional int64 y = 2;
}
================================================
FILE: moonlight/score/BUILD
================================================
# Description:
#   Score reading for OMR.

package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Umbrella library that re-exports the score-reading modules.
py_library(
    name = "score",
    deps = [
        ":measures",
        ":reader",
    ],
)

# Maps barline x coordinates to measure intervals (measures.py).
py_library(
    name = "measures",
    srcs = ["measures.py"],
    deps = ["//moonlight/protobuf:protobuf_py_pb2"],
)

# Reads Pages of glyphs and produces musical elements (reader.py).
py_library(
    name = "reader",
    srcs = ["reader.py"],
    deps = [
        ":measures",
        # absl dep
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/score/elements:clef",
        "//moonlight/score/state",
    ],
)

py_test(
    name = "reader_test",
    srcs = ["reader_test.py"],
    deps = [
        ":reader",
        # absl/testing dep
        # librosa dep
        "//moonlight/conversions",
        "//moonlight/protobuf:protobuf_py_pb2",
        "@magenta//protobuf:music_py_pb2",
    ],
)
================================================
FILE: moonlight/score/elements/BUILD
================================================
# Description:
#   Score elements which encapsulate custom logic for score reading.

package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Clef objects mapping staff y positions to MIDI pitches (clef.py).
py_library(
    name = "clef",
    srcs = ["clef.py"],
    deps = [
        # librosa dep
        "//moonlight/music:constants",
        "//moonlight/protobuf:protobuf_py_pb2",
    ],
)

py_test(
    name = "clef_test",
    srcs = ["clef_test.py"],
    deps = [
        ":clef",
        # absl/testing dep
        # librosa dep
    ],
)

# Key signature inference from accidental glyphs (key_signature.py).
py_library(
    name = "key_signature",
    srcs = ["key_signature.py"],
    deps = [
        ":clef",
        # librosa dep
        "//moonlight/music:constants",
        "//moonlight/protobuf:protobuf_py_pb2",
    ],
)

py_test(
    name = "key_signature_test",
    srcs = ["key_signature_test.py"],
    deps = [
        ":clef",
        ":key_signature",
        # absl/testing dep
        "//moonlight/protobuf:protobuf_py_pb2",
    ],
)
================================================
FILE: moonlight/score/elements/clef.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Clef logic for OMR.
A Clef object maps y positions on the staff to the MIDI pitch of the natural
note at the y position.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import librosa
from moonlight.music import constants
from moonlight.protobuf import musicscore_pb2
class Clef(object):
    """Maps staff y positions to MIDI pitches.

    Attributes:
        center_line_pitch: A _ScalePitch for the center line (3rd line of
            the staff). Subclasses assign this in their constructor.
    """

    center_line_pitch = None

    def y_position_to_midi(self, y_position):
        """Returns the MIDI pitch of the natural note at `y_position`."""
        shifted_pitch = self.center_line_pitch + y_position
        return shifted_pitch.midi
class TrebleClef(Clef):
    """Represents a treble clef."""

    def __init__(self):
        # The center (third) line of the staff is B4 in treble clef.
        center_midi = librosa.note_to_midi('B4')
        self.center_line_pitch = _ScalePitch(constants.MAJOR_SCALE, center_midi)
        self.glyph = musicscore_pb2.Glyph.CLEF_TREBLE
class BassClef(Clef):
    """Represents a bass clef."""

    def __init__(self):
        # The center (third) line of the staff is D3 in bass clef.
        center_midi = librosa.note_to_midi('D3')
        self.center_line_pitch = _ScalePitch(constants.MAJOR_SCALE, center_midi)
        self.glyph = musicscore_pb2.Glyph.CLEF_BASS
class _ScalePitch(object):
    """A natural note which can be offset to get another note.

    Attributes:
        scale: The scale which this pitch is based on. A list of MIDI pitch
            values spanning one octave.
        index: The index of the pitch's pitch class within the scale.
        octave: The index of the octave that the pitch is in, relative to the
            octave spanning the scale notes.
    """

    def __init__(self, scale, midi):
        """Initializes the pitch from a scale and an absolute MIDI value.

        Args:
            scale: List of MIDI pitch values spanning one octave. `midi`'s
                pitch class must be present in the scale.
            midi: Absolute MIDI pitch value of the natural note.

        Raises:
            ValueError: If `midi`'s pitch class is not in `scale` (from
                `list.index`).
        """
        self.scale = scale
        self.index = scale.index(midi % constants.NUM_SEMITONES_PER_OCTAVE)
        # Fix: use the named constant instead of a hard-coded 12, for
        # consistency with the `midi` property below (same value).
        self.octave = (midi - scale[0]) // constants.NUM_SEMITONES_PER_OCTAVE

    @property
    def midi(self):
        """The MIDI value for the pitch."""
        notes_per_octave = constants.NUM_SEMITONES_PER_OCTAVE
        return self.scale[self.index] + notes_per_octave * self.octave

    @property
    def pitch_index(self):
        """The index of the pitch in the C major scale."""
        return self.index + len(self.scale) * self.octave

    def __add__(self, interval):
        """Returns the natural note `interval` away on `self.scale`."""
        pitch = _ScalePitch(self.scale, self.midi)
        pitch_index = self.pitch_index + interval
        pitch.index = pitch_index % len(self.scale)
        pitch.octave = pitch_index // len(self.scale)
        return pitch
================================================
FILE: moonlight/score/elements/clef_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the clefs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import librosa
from moonlight.score.elements import clef
class ClefTest(absltest.TestCase):
    """Checks that clefs map staff y positions to the expected MIDI notes."""

    def testTrebleClef(self):
        # Map of staff y position -> expected note name in treble clef.
        expected_notes = {
            -8: 'A3',
            -6: 'C4',
            0: 'B4',
            1: 'C5',
            3: 'E5',
            4: 'F5',
            14: 'B6',
        }
        for y_position, note_name in expected_notes.items():
            self.assertEqual(clef.TrebleClef().y_position_to_midi(y_position),
                             librosa.note_to_midi(note_name))

    def testBassClef(self):
        # Map of staff y position -> expected note name in bass clef.
        expected_notes = {
            -10: 'A1',
            -7: 'D2',
            -5: 'F2',
            -1: 'C3',
            0: 'D3',
            6: 'C4',
            8: 'E4',
        }
        for y_position, note_name in expected_notes.items():
            self.assertEqual(clef.BassClef().y_position_to_midi(y_position),
                             librosa.note_to_midi(note_name))


if __name__ == '__main__':
    absltest.main()
================================================
FILE: moonlight/score/elements/key_signature.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Music key signature inference.
The accidentals classes are Accidentals, which are reset for each new measure,
and KeySignature, which is persisted for a staff at a time (because it is
expected to be repeated on each new staff). The key signature must follow the
expected pattern, or subsequent accidentals will fail to be added to it, and
should be added to the Accidentals instead.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import librosa
from moonlight.music import constants
from moonlight.protobuf import musicscore_pb2
import six
# Short alias; Glyph.Type values (SHARP, FLAT, NONE, ...) identify
# accidentals throughout this module.
Glyph = musicscore_pb2.Glyph  # pylint: disable=invalid-name
class _BaseAccidentals(object):
    """Holds accidentals which are not part of the key signature."""

    def __init__(self, clef, accidentals=None):
        self.clef = clef
        # Copy so callers cannot mutate our state through their reference.
        self._accidentals = dict(accidentals) if accidentals else {}

    def _normalize_position(self, position):
        """No octave normalization.

        Accidentals only apply to the current octave, whereas `KeySignature`
        overrides this to make its accidentals octave-invariant.

        Args:
            position: The vertical staff y position.

        Returns:
            The normalized position.
        """
        return position

    def get_accidental_for_position(self, position):
        """Returns the accidental at `position`, or Glyph.NONE if absent."""
        key = self._normalize_position(position)
        return self._accidentals.get(key, Glyph.NONE)
class Accidentals(_BaseAccidentals):
    """Simple map of staff y position to accidental value.

    Unlike KeySignature, positions are not normalized by octave, so an
    accidental applies only to the exact y position it was recorded at.
    """

    def __init__(self, clef):
        super(Accidentals, self).__init__(clef)

    def put(self, position, accidental):
        # Records the accidental glyph type at the given staff y position,
        # overwriting any previous accidental there.
        self._accidentals[position] = accidental
class KeySignature(_BaseAccidentals):
    """Music key signature.

    Tracks the expected order of accidentals in a key signature. If we detect
    that an accidental does not match the next expected accidental, it will be
    treated as a normal accidental and not part of the key signature.
    """

    def _normalize_position(self, position):
        """Normalize base notes by octave for the key signature.

        The key signature contains one accidental for one base note, which
        applies to the same pitch class in all octaves.

        Args:
            position: The staff y position of the glyph.

        Returns:
            The base note normalized by octave. This causes an accidental in
            the key signature to apply to the same note in all octaves.
        """
        return position % len(constants.MAJOR_SCALE)

    def try_put(self, position, accidental):
        """Adds an accidental to the key signature if applicable.

        Args:
            position: The accidental glyph y position.
            accidental: The accidental glyph type.

        Returns:
            True if the accidental was successfully added to the key
            signature. False if the key signature would be invalid when adding
            the new accidental.
        """
        can_put = self._can_put(position, accidental)
        if can_put:
            self._accidentals[self._normalize_position(position)] = accidental
        return can_put

    def _can_put(self, position, accidental):
        """Returns True if `accidental` is a valid next entry for the key."""
        if not self._accidentals:
            # First accidental: it must be the first entry of either the sharp
            # or the flat key signature sequence.
            # Fix: use the named constant instead of a hard-coded 12, for
            # consistency with get_next_accidental (same value).
            pitch_class = (
                self.clef.y_position_to_midi(position) %
                constants.NUM_SEMITONES_PER_OCTAVE)
            return (accidental in _KEY_SIGNATURE_PITCH_CLASS_LIST and
                    pitch_class == _KEY_SIGNATURE_PITCH_CLASS_LIST[accidental][0])
        return (position, accidental) == self.get_next_accidental()

    def get_next_accidental(self):
        """Predicts the next accidental which would be present in the key signature.

        Cannot predict the next accidental if the key signature is currently
        empty (C major), because the key could contain either sharps or flats.

        Returns:
            The expected y position of the next accidental if possible, or None.
            The expected accidental glyph type, or None.
        """
        # There must already be some accidentals, which are all sharps or all
        # flats. Get the base pitch class for each note that has an accidental.
        pitch_classes = [
            self.clef.y_position_to_midi(position) %
            constants.NUM_SEMITONES_PER_OCTAVE
            for position in self._accidentals.keys()
        ]
        # Determine the order of pitch classes (for either all sharps or all
        # flats).
        values = set(self._accidentals.values())
        if len(values) == 1:
            full_key_sig = _KEY_SIGNATURE_PITCH_CLASS_LIST[six.next(iter(values))]
        else:
            # Key signature is empty. Don't know whether to predict a sharp or
            # a flat.
            return None, None
        if len(pitch_classes) == len(full_key_sig):
            # No more accidentals to add.
            return None, None
        elif set(pitch_classes) == set(full_key_sig[:len(pitch_classes)]):
            # Use the next pitch class in the list.
            next_pitch_class = full_key_sig[len(pitch_classes)]
            accidental = six.next(iter(values))
            # The pitch class must match exactly one of the 7 y positions that
            # are allowed for this key signature.
            for y_position in _KEY_SIGNATURE_Y_POSITION_RANGES[self.clef.glyph,
                                                               accidental]:
                if (self.clef.y_position_to_midi(y_position) %
                        constants.NUM_SEMITONES_PER_OCTAVE) == next_pitch_class:
                    return y_position, accidental
            raise AssertionError('Failed to find the next accidental y position')
        else:
            # The current key signature is unrecognized.
            return None, None

    def get_type(self):
        """Returns whether this is a sharp, flat, or None (C major) signature."""
        return (six.next(iter(self._accidentals.values()))
                if self._accidentals else None)

    def __len__(self):
        """Returns the number of accidentals in the key signature."""
        return len(self._accidentals)
def _key_sig_pitch_classes(note_name, ascending_fifths):
    """Returns the ordered pitch classes of a full key signature.

    Args:
        note_name: Name of the natural note of the first accidental
            (e.g. 'F' for sharps, 'B' for flats).
        ascending_fifths: Whether to walk the circle of fifths upward
            (sharps) or downward (flats).

    Returns:
        A list of NUM_NOTES_IN_DIATONIC_SCALE pitch classes, in the order the
        accidentals appear in the key signature.
    """
    first_pitch_class = (
        librosa.note_to_midi(note_name + '0') %
        constants.NUM_SEMITONES_PER_OCTAVE)
    # Walk the circle of fifths forwards for sharps, backwards for flats.
    if ascending_fifths:
        order = list(constants.CIRCLE_OF_FIFTHS)
    else:
        order = list(constants.CIRCLE_OF_FIFTHS[::-1])
    start = order.index(first_pitch_class)
    # Doubling the order lets the slice wrap around the circle (e.g. from F
    # back to C) without itertools.cycle.
    wrapped = order + order
    return wrapped[start:start + constants.NUM_NOTES_IN_DIATONIC_SCALE]
# Maps an accidental glyph type to the full ordered list of pitch classes for
# a key signature consisting of that accidental.
_KEY_SIGNATURE_PITCH_CLASS_LIST = {
    # The sharp key signature starts with F#, and each subsequent note ascends
    # by a fifth.
    Glyph.SHARP:
        _key_sig_pitch_classes('F', ascending_fifths=True),
    # The flat key signature starts with Bb, and each subsequent note descends
    # by a fifth.
    Glyph.FLAT:
        _key_sig_pitch_classes('B', ascending_fifths=False),
}

# Maps the clef and type of accidentals in the key signature to the range of y
# positions where the key signature is shown.
_KEY_SIGNATURE_Y_POSITION_RANGES = {
    (Glyph.CLEF_TREBLE, Glyph.SHARP): range(-1, 6),  # A#4 to G#5
    (Glyph.CLEF_TREBLE, Glyph.FLAT): range(-3, 4),  # Fb4 to Eb5
    (Glyph.CLEF_BASS, Glyph.SHARP): range(-3, 4),  # A#2 to G#3
    (Glyph.CLEF_BASS, Glyph.FLAT): range(-5, 2),  # Fb2 to Eb3
}
================================================
FILE: moonlight/score/elements/key_signature_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for key signature inference."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from moonlight.protobuf import musicscore_pb2
from moonlight.score.elements import clef
from moonlight.score.elements import key_signature
class KeySignatureTest(absltest.TestCase):
    """Tests KeySignature accidental ordering and prediction.

    NOTE: try_put calls are order-sensitive; each test adds accidentals in the
    order they would appear in a printed key signature.
    """

    def testEmpty_noNextAccidental(self):
        # An empty (C major) signature cannot predict sharps vs. flats.
        self.assertEqual(
            key_signature.KeySignature(clef.TrebleClef()).get_next_accidental(),
            (None, None))

    def testGMajor(self):
        sig = key_signature.KeySignature(clef.TrebleClef())
        self.assertTrue(sig.try_put(+4, musicscore_pb2.Glyph.SHARP))  # F#
        self.assertEqual(sig.get_next_accidental(),
                         (+1, musicscore_pb2.Glyph.SHARP))  # C#

    def testGMajor_bassClef(self):
        # Same key as above, but y positions shift with the clef.
        sig = key_signature.KeySignature(clef.BassClef())
        self.assertTrue(sig.try_put(+2, musicscore_pb2.Glyph.SHARP))  # F#
        self.assertEqual(sig.get_next_accidental(),
                         (-1, musicscore_pb2.Glyph.SHARP))  # C#

    def testBMajor(self):
        sig = key_signature.KeySignature(clef.TrebleClef())
        self.assertTrue(sig.try_put(+4, musicscore_pb2.Glyph.SHARP))  # F#
        self.assertTrue(sig.try_put(+1, musicscore_pb2.Glyph.SHARP))  # C#
        self.assertTrue(sig.try_put(+5, musicscore_pb2.Glyph.SHARP))  # G#
        self.assertEqual(sig.get_next_accidental(),
                         (+2, musicscore_pb2.Glyph.SHARP))  # D#

    def testEFlatMajor(self):
        sig = key_signature.KeySignature(clef.TrebleClef())
        self.assertTrue(sig.try_put(0, musicscore_pb2.Glyph.FLAT))  # Bb
        self.assertTrue(sig.try_put(+3, musicscore_pb2.Glyph.FLAT))  # Eb
        self.assertTrue(sig.try_put(-1, musicscore_pb2.Glyph.FLAT))  # Ab
        self.assertEqual(sig.get_next_accidental(),
                         (+2, musicscore_pb2.Glyph.FLAT))  # Db

    def testEFlatMajor_bassClef(self):
        sig = key_signature.KeySignature(clef.BassClef())
        self.assertTrue(sig.try_put(-2, musicscore_pb2.Glyph.FLAT))  # Bb
        self.assertTrue(sig.try_put(+1, musicscore_pb2.Glyph.FLAT))  # Eb
        self.assertTrue(sig.try_put(-3, musicscore_pb2.Glyph.FLAT))  # Ab
        self.assertEqual(sig.get_next_accidental(),
                         (0, musicscore_pb2.Glyph.FLAT))  # Db

    def testCFlatMajor_noMoreAccidentals(self):
        # Adds all 7 flats one at a time; a next accidental must be predicted
        # after each one until the signature is complete.
        sig = key_signature.KeySignature(clef.TrebleClef())
        self.assertTrue(sig.try_put(0, musicscore_pb2.Glyph.FLAT))  # Bb
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(+3, musicscore_pb2.Glyph.FLAT))  # Eb
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(-1, musicscore_pb2.Glyph.FLAT))  # Ab
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(+2, musicscore_pb2.Glyph.FLAT))  # Db
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(-2, musicscore_pb2.Glyph.FLAT))  # Gb
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(+1, musicscore_pb2.Glyph.FLAT))  # Cb
        self.assertNotEqual(sig.get_next_accidental(), (None, None))
        self.assertTrue(sig.try_put(-3, musicscore_pb2.Glyph.FLAT))  # Fb
        # Already at Cb major, no more accidentals to add.
        self.assertEqual(sig.get_next_accidental(), (None, None))


if __name__ == '__main__':
    absltest.main()
================================================
FILE: moonlight/score/measures.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Represents the measures of a staff system.
Converts bar x coordinates to a series of measures, with the x interval covered
by each measure.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.protobuf import musicscore_pb2
# Sentinel x coordinate marking the open-ended right edge of the final
# measure: the largest value that fits in the int32 `Bar.x` proto field.
# sys.maxint overflows the int32 proto field.
MEASURE_MAX_X = 2**31 - 1
class Measures(object):
    """Represents the measures of a staff system."""

    def __init__(self, staff_system):
        # Each entry is a (start_bar, end_bar) pair bounding one measure.
        self.bars = list(_get_bar_intervals(staff_system))

    def size(self):
        """Returns the number of measures in the staff system.

        Returns:
            The number of measures.
        """
        return len(self.bars)

    def get_measure(self, glyph):
        """Gets the measure number of a `tensorflow.moonlight.Glyph`.

        Args:
            glyph: A `Glyph` message.

        Returns:
            The measure index, or -1 if it lies outside of the measures.
        """
        matches = (
            index
            for index, (start_bar, end_bar) in enumerate(self.bars)
            if start_bar.x <= glyph.x < end_bar.x)
        return next(matches, -1)
def _get_bar_intervals(staff_system):
    """Yields (start_bar, end_bar) pairs bounding each measure."""
    bars = staff_system.bar
    num_bars = len(bars)
    if num_bars == 0:
        # TODO(ringw): Store the image dimensions in the Page message, so that
        # we can use the actual width as the end of the measure.
        yield (musicscore_pb2.StaffSystem.Bar(x=0),
               musicscore_pb2.StaffSystem.Bar(x=MEASURE_MAX_X))
    elif num_bars == 1:
        # Single barline is at the beginning of the staff.
        yield bars[0], musicscore_pb2.StaffSystem.Bar(x=MEASURE_MAX_X)
    else:
        # Each adjacent pair of barlines bounds one measure.
        for index in range(num_bars - 1):
            yield bars[index], bars[index + 1]
================================================
FILE: moonlight/score/reader.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reads Pages of glyphs and outputs a NoteSequence."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from moonlight.protobuf import musicscore_pb2
from moonlight.score import measures
from moonlight.score import state
from moonlight.score.elements import clef
from six import moves
# The expected y position for clefs. Clefs detected at other y positions are
# ignored by _read_clef.
TREBLE_CLEF_EXPECTED_Y = -2
BASS_CLEF_EXPECTED_Y = 2

# Rest durations in beats (4 beats to a whole note). Only these three rest
# types currently have handlers in GLYPH_HANDLERS_.
REST_DURATIONS_ = {
    musicscore_pb2.Glyph.REST_QUARTER: 4 / 4,
    musicscore_pb2.Glyph.REST_EIGHTH: 4 / 8,
    musicscore_pb2.Glyph.REST_SIXTEENTH: 4 / 16,
}
class ScoreReader(object):
    """Reads a Score proto and interprets musical elements from the glyphs.

    Given a Page containing glyphs, holds global state for the entire score,
    and per-measure state (accidentals). Each glyph is added to the
    NoteSequence based on the current state.

    OMR is a work in progress. Voice detection is not yet implemented; the
    score is assumed to be monophonic.
    """

    def __init__(self):
        self.time = 0.0
        self.score_state = state.ScoreState()

    def __call__(self, score):
        """Reads a `tensorflow.moonlight.Score` message.

        Modifies the message in place to add detected musical elements.

        Args:
            score: A `tensorflow.moonlight.Score` message.

        Returns:
            The same Score object.
        """
        for page in score.page:
            self.read_page(page)
        # Modifies the score in place.
        return score

    def read_page(self, page):
        """Reads a `tensorflow.moonlight.Page` message.

        Modifies the page in place to add detected musical elements.

        Args:
            page: A `tensorflow.moonlight.Page` message.

        Returns:
            The same Page object.
        """
        for system in page.system:
            self.read_system(system)
        return page

    def read_system(self, system):
        """Reads all staves of a staff system, one measure at a time."""
        self.score_state.num_staves(len(system.staff))
        system_measures = measures.Measures(system)
        for measure_num in moves.xrange(system_measures.size()):
            for staff, staff_state in zip(system.staff,
                                          self.score_state.staves):
                for glyph in staff.glyph:
                    if system_measures.get_measure(glyph) == measure_num:
                        self._read_glyph(glyph, staff_state)
            self.score_state.add_measure()

    def _read_glyph(self, glyph, staff_state):
        """Dispatches a glyph to its type-specific handler, if one exists."""
        if glyph.type not in ScoreReader.GLYPH_HANDLERS_:
            logging.warning('Handler not implemented: %s',
                            musicscore_pb2.Glyph.Type.Name(glyph.type))
            return
        # Handlers are unbound; they are always called as
        # handler(self, staff_state, glyph).
        ScoreReader.GLYPH_HANDLERS_[glyph.type](self, staff_state, glyph)

    def _read_clef(self, staff_state, glyph):
        """Reads a clef glyph.

        If the clef is at the expected y position, set the current clef.

        Args:
            staff_state: The state of the staff that the glyph is on.
            glyph: A glyph of type CLEF_TREBLE or CLEF_BASS.

        Raises:
            ValueError: If glyph is an unexpected type.
        """
        if glyph.type == musicscore_pb2.Glyph.CLEF_TREBLE:
            if glyph.y_position == TREBLE_CLEF_EXPECTED_Y:
                staff_state.set_clef(clef.TrebleClef())
        elif glyph.type == musicscore_pb2.Glyph.CLEF_BASS:
            if glyph.y_position == BASS_CLEF_EXPECTED_Y:
                staff_state.set_clef(clef.BassClef())
        else:
            raise ValueError('Unknown clef of type: ' +
                             musicscore_pb2.Glyph.Type.Name(glyph.type))

    def _read_note(self, staff_state, glyph):
        """Reads a notehead glyph and stores the detected Note on it."""
        staff_state.measure_state.on_read_notehead()
        glyph.note.CopyFrom(staff_state.measure_state.get_note(glyph))

    def _read_rest(self, staff_state, glyph):
        """Advances the staff's time by the rest's duration in beats."""
        staff_state.set_time(
            staff_state.get_time() + REST_DURATIONS_[glyph.type])

    def _read_accidental(self, staff_state, glyph):
        """Records an accidental for the current measure."""
        staff_state.measure_state.set_accidental(glyph.y_position, glyph.type)

    def _no_op_handler(self, staff_state, glyph):
        """Ignores a glyph.

        Bug fix: handlers are invoked as handler(self, staff_state, glyph)
        (see _read_glyph), but this handler previously accepted only
        (self, glyph), so dispatching a NONE glyph raised TypeError.
        """
        del staff_state, glyph  # Unused.

    # Unbound handler functions keyed by glyph type. Types not listed here
    # (e.g. REST_WHOLE, TIME_COMMON, CLEF_C) fall through to the warning in
    # _read_glyph.
    GLYPH_HANDLERS_ = {
        musicscore_pb2.Glyph.NONE: _no_op_handler,
        musicscore_pb2.Glyph.CLEF_TREBLE: _read_clef,
        musicscore_pb2.Glyph.CLEF_BASS: _read_clef,
        musicscore_pb2.Glyph.NOTEHEAD_FILLED: _read_note,
        musicscore_pb2.Glyph.NOTEHEAD_EMPTY: _read_note,
        musicscore_pb2.Glyph.NOTEHEAD_WHOLE: _read_note,
        musicscore_pb2.Glyph.REST_QUARTER: _read_rest,
        musicscore_pb2.Glyph.REST_EIGHTH: _read_rest,
        musicscore_pb2.Glyph.REST_SIXTEENTH: _read_rest,
        musicscore_pb2.Glyph.FLAT: _read_accidental,
        musicscore_pb2.Glyph.NATURAL: _read_accidental,
        musicscore_pb2.Glyph.SHARP: _read_accidental,
        musicscore_pb2.Glyph.DOUBLE_SHARP: _read_accidental,
    }
================================================
FILE: moonlight/score/reader_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the OMR score reader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import librosa
from protobuf import music_pb2
from moonlight import conversions
from moonlight.protobuf import musicscore_pb2
from moonlight.score import reader
# pylint: disable=invalid-name
# Short aliases for frequently-constructed message types.
Glyph = musicscore_pb2.Glyph
Note = music_pb2.NoteSequence.Note
Point = musicscore_pb2.Point
class ReaderTest(absltest.TestCase):
    def testTreble_simple(self):
        # One treble clef and one notehead on the center line (y = 0), which
        # is B4 in treble clef.
        staff = musicscore_pb2.Staff(
            staffline_distance=10,
            center_line=[Point(x=0, y=50), Point(x=100, y=50)],
            glyph=[
                Glyph(
                    type=Glyph.CLEF_TREBLE,
                    x=1,
                    y_position=reader.TREBLE_CLEF_EXPECTED_Y),
                Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=0),
            ])
        notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
            musicscore_pb2.Page(system=[musicscore_pb2.StaffSystem(
                staff=[staff])])))
        self.assertEqual(
            notes,
            music_pb2.NoteSequence(notes=[
                Note(pitch=librosa.note_to_midi('B4'), start_time=0, end_time=1)
            ]))
    def testBass_simple(self):
        # One bass clef and one notehead on the center line (y = 0), which is
        # D3 in bass clef.
        staff = musicscore_pb2.Staff(
            staffline_distance=10,
            center_line=[Point(x=0, y=50), Point(x=100, y=50)],
            glyph=[
                Glyph(
                    type=Glyph.CLEF_BASS,
                    x=1,
                    y_position=reader.BASS_CLEF_EXPECTED_Y),
                Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=0),
            ])
        notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
            musicscore_pb2.Page(system=[musicscore_pb2.StaffSystem(
                staff=[staff])])))
        self.assertEqual(
            notes,
            music_pb2.NoteSequence(notes=[
                Note(pitch=librosa.note_to_midi('D3'), start_time=0, end_time=1)
            ]))
def testTreble_accidentals(self):
staff_1 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=100, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-6),
Glyph(type=Glyph.FLAT, x=16, y_position=-4),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=20, y_position=-4),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=30, y_position=-2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=40, y_position=-4),
])
staff_2 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=150), Point(x=100, y=150)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-6),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=20, y_position=-4),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=30, y_position=-2),
Glyph(type=Glyph.SHARP, x=35, y_position=-2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=40, y_position=-2),
Glyph(type=Glyph.NATURAL, x=45, y_position=-2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=-2),
])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[
musicscore_pb2.StaffSystem(staff=[staff_1]),
musicscore_pb2.StaffSystem(staff=[staff_2])
])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
# First staff.
Note(pitch=librosa.note_to_midi('C4'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('Eb4'), start_time=1, end_time=2),
Note(pitch=librosa.note_to_midi('G4'), start_time=2, end_time=3),
Note(pitch=librosa.note_to_midi('Eb4'), start_time=3, end_time=4),
# Second staff.
Note(pitch=librosa.note_to_midi('C4'), start_time=4, end_time=5),
Note(pitch=librosa.note_to_midi('E4'), start_time=5, end_time=6),
Note(pitch=librosa.note_to_midi('G4'), start_time=6, end_time=7),
Note(pitch=librosa.note_to_midi('G#4'), start_time=7, end_time=8),
Note(pitch=librosa.note_to_midi('G4'), start_time=8, end_time=9),
]))
def testChords(self):
stem_1 = musicscore_pb2.LineSegment(
start=Point(x=20, y=10), end=Point(x=20, y=70))
stem_2 = musicscore_pb2.LineSegment(
start=Point(x=50, y=10), end=Point(x=50, y=70))
staff = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=100, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
# Chord of 2 notes.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-4, stem=stem_1),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-1, stem=stem_1),
# Note not attached to a stem.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=30, y_position=3),
# Chord of 3 notes.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=40, y_position=0, stem=stem_2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=60, y_position=2, stem=stem_2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=60, y_position=4, stem=stem_2),
])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[musicscore_pb2.StaffSystem(
staff=[staff])])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
# First chord.
Note(pitch=librosa.note_to_midi('E4'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('A4'), start_time=0, end_time=1),
# Note without a stem.
Note(pitch=librosa.note_to_midi('E5'), start_time=1, end_time=2),
# Second chord.
Note(pitch=librosa.note_to_midi('B4'), start_time=2, end_time=3),
Note(pitch=librosa.note_to_midi('D5'), start_time=2, end_time=3),
Note(pitch=librosa.note_to_midi('F5'), start_time=2, end_time=3),
]))
def testBeams(self):
beam_1 = musicscore_pb2.LineSegment(
start=Point(x=10, y=20), end=Point(x=40, y=20))
beam_2 = musicscore_pb2.LineSegment(
start=Point(x=70, y=40), end=Point(x=90, y=40))
beam_3 = musicscore_pb2.LineSegment(
start=Point(x=70, y=60), end=Point(x=90, y=60))
staff = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=100, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
# 2 eighth notes.
Glyph(
type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-4, beam=[beam_1]),
Glyph(
type=Glyph.NOTEHEAD_FILLED, x=40, y_position=-1, beam=[beam_1]),
# 1 quarter note.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=0),
# 2 sixteenth notes.
Glyph(
type=Glyph.NOTEHEAD_FILLED,
x=60,
y_position=-2,
beam=[beam_2, beam_3]),
Glyph(
type=Glyph.NOTEHEAD_FILLED,
x=90,
y_position=2,
beam=[beam_2, beam_3]),
])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[musicscore_pb2.StaffSystem(
staff=[staff])])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
Note(pitch=librosa.note_to_midi('E4'), start_time=0, end_time=0.5),
Note(pitch=librosa.note_to_midi('A4'), start_time=0.5, end_time=1),
Note(pitch=librosa.note_to_midi('B4'), start_time=1, end_time=2),
Note(pitch=librosa.note_to_midi('G4'), start_time=2, end_time=2.25),
Note(
pitch=librosa.note_to_midi('D5'), start_time=2.25,
end_time=2.5),
]))
def testAllNoteheadTypes(self):
staff = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=100, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-6),
Glyph(type=Glyph.NOTEHEAD_EMPTY, x=10, y_position=-6),
Glyph(type=Glyph.NOTEHEAD_WHOLE, x=10, y_position=-6),
])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[musicscore_pb2.StaffSystem(
staff=[staff])])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
Note(pitch=librosa.note_to_midi('C4'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('C4'), start_time=1, end_time=3),
Note(pitch=librosa.note_to_midi('C4'), start_time=3, end_time=7),
]))
def testStaffSystems(self):
# 2 staff systems on separate pages, each with 2 staves, and no bars.
system_1_staff_1 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=100, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=-6),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=-2),
])
system_1_staff_2 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=150), Point(x=100, y=150)],
glyph=[
Glyph(
type=Glyph.CLEF_BASS,
x=2,
y_position=reader.BASS_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=0),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=40, y_position=2),
# Played after the second note in the first staff, although it is to
# the left of it.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=45, y_position=4),
])
system_2_staff_1 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=250), Point(x=100, y=250)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
Glyph(type=Glyph.REST_QUARTER, x=20, y_position=0),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=-2),
])
system_2_staff_2 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=250), Point(x=100, y=250)],
glyph=[
Glyph(
type=Glyph.CLEF_BASS,
x=2,
y_position=reader.BASS_CLEF_EXPECTED_Y),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=10, y_position=0),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=40, y_position=2),
])
notes = conversions.score_to_notesequence(reader.ScoreReader()(
musicscore_pb2.Score(page=[
musicscore_pb2.Page(system=[
musicscore_pb2.StaffSystem(
staff=[system_1_staff_1, system_1_staff_2]),
]),
musicscore_pb2.Page(system=[
musicscore_pb2.StaffSystem(
staff=[system_2_staff_1, system_2_staff_2]),
]),
]),))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
# System 1, staff 1.
Note(pitch=librosa.note_to_midi('C4'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('G4'), start_time=1, end_time=2),
# System 1, staff 2.
Note(pitch=librosa.note_to_midi('D3'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('F3'), start_time=1, end_time=2),
Note(pitch=librosa.note_to_midi('A3'), start_time=2, end_time=3),
# System 2, staff 1.
# Quarter rest.
Note(pitch=librosa.note_to_midi('G4'), start_time=4, end_time=5),
# System 2, staff 2.
Note(pitch=librosa.note_to_midi('D3'), start_time=3, end_time=4),
Note(pitch=librosa.note_to_midi('F3'), start_time=4, end_time=5),
]))
def testMeasures(self):
# 2 staves in the same staff system with multiple bars.
staff_1 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=50), Point(x=300, y=50)],
glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=1,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
# Key signature.
Glyph(type=Glyph.SHARP, x=10, y_position=+4),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=20, y_position=-2),
# Accidental.
Glyph(type=Glyph.FLAT, x=40, y_position=-1),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=-1),
# Second bar.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=120, y_position=0),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=180, y_position=+4),
# Third bar.
# Accidental not propagated to this note.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=220, y_position=-1),
])
staff_2 = musicscore_pb2.Staff(
staffline_distance=10,
center_line=[Point(x=0, y=150), Point(x=300, y=150)],
glyph=[
Glyph(
type=Glyph.CLEF_BASS,
x=1,
y_position=reader.BASS_CLEF_EXPECTED_Y),
# Key signature.
Glyph(type=Glyph.FLAT, x=15, y_position=-2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=20, y_position=-2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=50, y_position=+2),
# Second bar.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=150, y_position=-2),
# Third bar.
Glyph(type=Glyph.REST_QUARTER, x=220, y_position=0),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=280, y_position=-2),
])
staff_system = musicscore_pb2.StaffSystem(
staff=[staff_1, staff_2],
bar=[_bar(0), _bar(100), _bar(200),
_bar(300)])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[staff_system])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
# Staff 1, bar 1.
Note(pitch=librosa.note_to_midi('G4'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('Ab4'), start_time=1, end_time=2),
# Staff 1, bar 2.
Note(pitch=librosa.note_to_midi('B4'), start_time=2, end_time=3),
Note(pitch=librosa.note_to_midi('F#5'), start_time=3, end_time=4),
# Staff 1, bar 3.
Note(pitch=librosa.note_to_midi('A4'), start_time=4, end_time=5),
# Staff 2, bar 1.
Note(pitch=librosa.note_to_midi('Bb2'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('F3'), start_time=1, end_time=2),
# Staff 2, bar 2.
Note(pitch=librosa.note_to_midi('Bb2'), start_time=2, end_time=3),
# Staff 2, bar 3.
Note(pitch=librosa.note_to_midi('Bb2'), start_time=5, end_time=6),
]))
def testKeySignatures(self):
# One staff per system, two systems.
staff_1 = musicscore_pb2.Staff(glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=5,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
# D major key signature.
Glyph(type=Glyph.SHARP, x=15, y_position=+4),
Glyph(type=Glyph.SHARP, x=25, y_position=+1),
# Accidental which cannot be interpreted as part of the key
# signature.
Glyph(type=Glyph.SHARP, x=35, y_position=+2),
Glyph(type=Glyph.NOTEHEAD_FILLED, x=45, y_position=+2), # D#5
Glyph(type=Glyph.NOTEHEAD_EMPTY, x=55, y_position=+1), # C#5
Glyph(type=Glyph.NOTEHEAD_FILLED, x=65, y_position=-3), # F#4
# New measure. The key signature should be retained.
Glyph(type=Glyph.NOTEHEAD_EMPTY, x=105, y_position=-3), # F#4
Glyph(type=Glyph.NOTEHEAD_FILLED, x=125, y_position=+1), # C#5
# Accidental is not retained.
Glyph(type=Glyph.NOTEHEAD_FILLED, x=145, y_position=+2), # D5
])
staff_2 = musicscore_pb2.Staff(glyph=[
Glyph(
type=Glyph.CLEF_TREBLE,
x=5,
y_position=reader.TREBLE_CLEF_EXPECTED_Y),
# No key signature on this line. No accidentals.
Glyph(type=Glyph.NOTEHEAD_EMPTY, x=25, y_position=-3), # F4
Glyph(type=Glyph.NOTEHEAD_EMPTY, x=45, y_position=+1), # C5
])
notes = conversions.page_to_notesequence(reader.ScoreReader().read_page(
musicscore_pb2.Page(system=[
musicscore_pb2.StaffSystem(
staff=[staff_1], bar=[_bar(0), _bar(100),
_bar(200)]),
musicscore_pb2.StaffSystem(staff=[staff_2]),
])))
self.assertEqual(
notes,
music_pb2.NoteSequence(notes=[
# First measure.
Note(pitch=librosa.note_to_midi('D#5'), start_time=0, end_time=1),
Note(pitch=librosa.note_to_midi('C#5'), start_time=1, end_time=3),
Note(pitch=librosa.note_to_midi('F#4'), start_time=3, end_time=4),
# Second measure.
Note(pitch=librosa.note_to_midi('F#4'), start_time=4, end_time=6),
Note(pitch=librosa.note_to_midi('C#5'), start_time=6, end_time=7),
Note(pitch=librosa.note_to_midi('D5'), start_time=7, end_time=8),
# Third measure on a new line, with no key signature.
Note(pitch=librosa.note_to_midi('F4'), start_time=8, end_time=10),
Note(pitch=librosa.note_to_midi('C5'), start_time=10, end_time=12),
]))
def _bar(x):
  """Builds a standard barline at horizontal position `x`."""
  bar_message = musicscore_pb2.StaffSystem.Bar
  return bar_message(x=x, type=bar_message.STANDARD_BAR)
if __name__ == '__main__':
  # Run all test cases via absl's test runner.
  absltest.main()
================================================
FILE: moonlight/score/state/BUILD
================================================
# Description:
# The music score state. Holds information such as the key signature and current time in the score,
# which is used for reading notes from the score.
package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Umbrella library re-exporting the score state modules.
py_library(
    name = "state",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":measure",
        ":staff",
        # six dep
    ],
)

# Per-measure state: clef, key signature, accidentals, and chord timing.
py_library(
    name = "measure",
    srcs = ["measure.py"],
    srcs_version = "PY2AND3",
    deps = [
        # enum34 dep
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/score/elements:key_signature",
        "@magenta//protobuf:music_py_pb2",
    ],
)

# Per-staff state wrapping the current MeasureState.
py_library(
    name = "staff",
    srcs = ["staff.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":measure",
        "//moonlight/score/elements:clef",
    ],
)
================================================
FILE: moonlight/score/state/__init__.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The global state for the entire score."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.score.state import staff as staff_state
from six import moves
class ScoreState(object):
  """The global state for the entire score.

  Represents the state of the score across multiple staff systems. The staff
  system state does not change on a new line, unless the number of staves
  changes. `num_staves` should be called for every new staff system, to reset
  the staff state properly.

  Attributes:
    staves: A list of StaffState objects, representing each staff in the
      current staff system.
  """

  def __init__(self):
    # No staves until the first staff system calls `num_staves`.
    self.staves = []

  def num_staves(self, num_staves):
    """Updates the score to have the given number of staves.

    If `num_staves` matches the current `len(self.staves)`, copies the persisted
    state from the previous staves to the new staff system. Otherwise,
    discards any current staves and constructs `num_staves` new staves.

    Args:
      num_staves: The number of staves for the current staff system.
    """
    time = self.add_measure()
    # Bug fix: the original compared against `self.num_staves` -- the bound
    # method itself, which never equals an int -- so persisted staff state
    # (clef, etc.) was always discarded. Compare against the argument.
    if len(self.staves) != num_staves:
      self.staves = [
          staff_state.StaffState(time) for _ in moves.xrange(num_staves)
      ]
    else:
      self.staves = [staff.new_staff(time) for staff in self.staves]

  def add_measure(self):
    """Adds a new measure for the current staff system.

    Called on every bar. Updates each staff.

    Returns:
      The start time of the new measure, which is the max of the current time
      of each current staff. 0 if there are no staves yet.
    """
    time = (
        max([staff.get_time() for staff in self.staves]) if self.staves else 0)
    for staff in self.staves:
      staff.add_measure(time)
    return time
================================================
FILE: moonlight/score/state/measure.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The score state which is not persisted between measures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import enum
from protobuf import music_pb2
from moonlight.protobuf import musicscore_pb2
from moonlight.score.elements import key_signature as key_signature_module
# Semitone shift applied to a note's diatonic pitch for each accidental glyph.
ACCIDENTAL_PITCH_SHIFT_ = {
    # TODO(ringw): Detect 2 adjacent flats as a double flat.
    musicscore_pb2.Glyph.FLAT:
        -1,
    musicscore_pb2.Glyph.NATURAL:
        0,
    # NONE means "no accidental in effect"; no shift.
    musicscore_pb2.Glyph.NONE:
        0,
    musicscore_pb2.Glyph.SHARP:
        +1,
    musicscore_pb2.Glyph.DOUBLE_SHARP:
        +2,
}
class _KeySignatureState(enum.Enum):
  """Whether accidental glyphs may still belong to the key signature."""
  # Accidentals at the start of the measure may still extend the key signature.
  KEY_SIGNATURE = 1
  # The key signature is complete; further accidentals are per-measure only.
  ACCIDENTALS = 2
class MeasureState(object):
  """State of a single measure of a staff.

  Attributes:
    clef: The current clef.
    key_signature: The current `KeySignature`.
    chords: A map from stem (tuple `((x0, y0), (x1, y1))`) to the first note
      that was read and is attached to the stem. Subsequent notes attached to
      the same stem will read their start and end time from the first note.
    time: The current time in the measure. Absolute time relative to the start
      of the score. float.
  """

  def __init__(self, start_time, clef, key_signature=None):
    """Initializes a new measure.

    Args:
      start_time: The start time (in quarter notes) of the measure.
      clef: A `Clef`.
      key_signature: The previously detected key signature (optional). If
        present, do not detect a key signature in this measure. This should be
        taken from the previous measure on this staff if this is not the first
        measure. It should not be propagated from one staff to the next,
        because we expect the key signature to be repeated on each staff and we
        will re-detect it.
    """
    self.time = start_time
    self.clef = clef
    self.key_signature = (
        key_signature or key_signature_module.KeySignature(clef))
    self._accidentals = key_signature_module.Accidentals(clef)
    # If a key signature was carried over, any accidentals we see belong to
    # this measure; otherwise they may still extend the key signature.
    self._key_signature_state = (
        _KeySignatureState.ACCIDENTALS
        if key_signature else _KeySignatureState.KEY_SIGNATURE)
    self.chords = {}

  def new_measure(self, start_time):
    """Constructs a new MeasureState for the next measure.

    The clef and key signature persist; accidentals and chords reset.

    Args:
      start_time: The start time of the new measure.

    Returns:
      A new MeasureState object.
    """
    return MeasureState(
        start_time,
        clef=self.clef,
        # Deep copy so later mutations of the new measure's key signature
        # cannot affect this (already completed) measure.
        key_signature=copy.deepcopy(self.key_signature))

  def set_accidental(self, y_position, accidental):
    """Adds a glyph to the key signature or accidentals.

    Args:
      y_position: The position of the accidental.
      accidental: The accidental value.
    """
    if self._key_signature_state == _KeySignatureState.KEY_SIGNATURE:
      if self.key_signature.try_put(y_position, accidental):
        return
      # The glyph does not fit the key signature pattern; from now on, all
      # accidentals in this measure are per-measure accidentals.
      self._key_signature_state = _KeySignatureState.ACCIDENTALS
    self._accidentals.put(y_position, accidental)

  def get_note(self, glyph):
    """Converts a Glyph to a Note.

    Gets the note timing from an existing chord if available, or increments the
    current measure time otherwise.

    Args:
      glyph: A Glyph message. Type must be one of NOTEHEAD_*.

    Returns:
      A Note message.
    """
    # A per-measure accidental takes precedence over the key signature.
    accidental = self._accidentals.get_accidental_for_position(glyph.y_position)
    if accidental == musicscore_pb2.Glyph.NONE:
      accidental = self.key_signature.get_accidental_for_position(
          glyph.y_position)
    pitch = (
        self.clef.y_position_to_midi(glyph.y_position) +
        ACCIDENTAL_PITCH_SHIFT_[accidental])
    first_note_in_chord = None
    if glyph.HasField('stem'):
      # Try to get the timing from another note in the same chord. The stem
      # endpoints form a hashable key identifying the chord.
      stem = ((glyph.stem.start.x, glyph.stem.start.y), (glyph.stem.end.x,
                                                         glyph.stem.end.y))
      if stem in self.chords:
        first_note_in_chord = self.chords[stem]
    else:
      stem = None
    if first_note_in_chord:
      # Chord members sound simultaneously; reuse the first note's timing
      # without advancing the measure time.
      start_time, end_time = (first_note_in_chord.start_time,
                              first_note_in_chord.end_time)
    else:
      # TODO(ringw): Check all note durations, not just the first seen in a
      # chord, and use the median detected duration.
      duration = _get_note_duration(glyph)
      start_time, end_time = self.time, self.time + duration
      self.time += duration
    note = music_pb2.NoteSequence.Note(
        pitch=pitch, start_time=start_time, end_time=end_time)
    if stem:
      self.chords[stem] = note
    return note

  def set_clef(self, clef):
    """Sets the clef, and resets the key signature if necessary."""
    if clef != self.clef:
      # A clef change invalidates staff-position-based state.
      self._key_signature_state = _KeySignatureState.KEY_SIGNATURE
      self.key_signature = key_signature_module.KeySignature(clef)
      self._accidentals = key_signature_module.Accidentals(clef)
      self.clef = clef

  def on_read_notehead(self):
    """Called after a notehead has been read.

    The key signature should occur before any noteheads in the measure. This
    causes subsequent accidental glyphs to be read as accidentals, and not part
    of the key signature.
    """
    self._key_signature_state = _KeySignatureState.ACCIDENTALS
def _get_note_duration(note):
  """Determines the duration of a notehead glyph.

  This depends on the glyph type, beams (which each halve the duration), and
  dots (which each add a fractional duration). In the future, notes may be
  recognized as a tuplet, which will result in a Fraction duration. For now,
  the duration is a float, because the denominator is always a sum of powers
  of two.

  Args:
    note: A `Glyph` of a notehead type.

  Returns:
    The float duration of the note, in quarter notes.

  Raises:
    ValueError: If `note` is not a notehead type.
  """
  glyph = musicscore_pb2.Glyph
  if note.type == glyph.NOTEHEAD_FILLED:
    # A bare filled notehead is a quarter note (duration 1); each attached
    # beam halves it (eighth, sixteenth, ...).
    duration = 1.0 / (2 ** len(note.beam))
  elif note.type == glyph.NOTEHEAD_EMPTY:
    duration = 2.0  # Half note.
  elif note.type == glyph.NOTEHEAD_WHOLE:
    duration = 4.0  # Whole note.
  else:
    raise ValueError('Expected a notehead, got: %s' % note)
  # Each augmentation dot adds half of the previous increment, starting with
  # half of the undotted duration.
  increment = duration
  for _ in note.dot:
    increment /= 2.
    duration += increment
  return duration
================================================
FILE: moonlight/score/state/staff.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The state for a single staff of a staff system."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.score.elements import clef as clef_module
from moonlight.score.state import measure
class StaffState(object):
  """The state for a single staff of a staff system.

  Holds the current measure of the staff, which has per-measure state (e.g.
  accidentals). Other state is copied to a new measure at each barline, and
  copied to a new StaffState representing a new line when `new_staff` is
  called.
  """

  def __init__(self, start_time, clef=None):
    """Creates the staff state with an initial measure.

    Args:
      start_time: Start time of the staff's first measure.
      clef: Optional initial clef; defaults to a treble clef.
    """
    if clef is None:
      clef = clef_module.TrebleClef()
    self.measure_state = measure.MeasureState(start_time, clef=clef)

  def add_measure(self, start_time):
    """Replaces `measure_state` with the state for a new measure.

    Persistent state (clef, key signature) carries over; per-measure state
    (accidentals, chords) is reset.

    Args:
      start_time: The start time of the new measure.
    """
    self.measure_state = self.measure_state.new_measure(start_time)

  def new_staff(self, start_time):
    """Copies the StaffState to a new staff on a new line.

    Args:
      start_time: Start time of the first measure of the new staff.

    Returns:
      A new StaffState instance.
    """
    # Don't persist the key signature between staves, since we expect it to be
    # repeated at the start of each line. Only the clef carries over.
    return StaffState(start_time, clef=self.measure_state.clef)

  def get_key_signature(self):
    """Returns the key signature at the current point in time."""
    return self.measure_state.key_signature

  def get_time(self):
    """Returns the current time of the current measure."""
    return self.measure_state.time

  def set_time(self, time):
    """Updates the current time of the current measure.

    Args:
      time: A floating-point time. Updates `self.measure_state`.
    """
    self.measure_state.time = time

  def set_clef(self, clef):
    """Updates the clef of the current measure.

    Args:
      clef: A TrebleClef or BassClef. Updates `self.measure_state`.
    """
    self.measure_state.set_clef(clef)
================================================
FILE: moonlight/score_processors.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processors that need to visit each page of the score in one pass.
These are intended for detecting musical elements, where musical context may
span staff systems and pages (e.g. the time signature). Musical elements (e.g.
notes) are added to the `Score` message directly.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.score import reader
def create_processors():
  """Yields the score processors to run, in order (currently just the reader)."""
  yield reader.ScoreReader()
def process(score):
  """Processes a Score.

  Detects notes in the Score, and returns the Score in place.

  Args:
    score: A `Score` message.

  Returns:
    A `Score` message with `Note`s added to the `Glyph`s where applicable.
  """
  result = score
  # Each processor consumes the previous processor's output.
  for processor in create_processors():
    result = processor(result)
  return result
================================================
FILE: moonlight/scripts/imslp_pdfs_to_pngs.sh
================================================
#!/bin/bash
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Extracts PNGs from the IMSLP backup. The output PNGs are suitable for training
# and running OMR. See IMSLP for the latest details on the backup:
# http://imslp.org/wiki/IMSLP:Backups
#
# Usage: imslp_pdfs_to_pngs.sh INPUT_DIR OUTPUT_DIR
#
# Fixes relative to the original: the shebang is now the first line of the
# file (it was below the license header, where the kernel ignores it);
# `exit -1` replaced with the portable `exit 1`; stderr redirection uses the
# shell-builtin `>&2` instead of the non-portable /dev/stderr path.

INPUT_DIR="$1"
OUTPUT_DIR="$2"

if ! [[ -d "$INPUT_DIR" ]]; then
  echo "First argument must be a directory" >&2
  exit 1
fi
if ! [[ -d "$OUTPUT_DIR" ]]; then
  mkdir -v "$OUTPUT_DIR"
fi
if ! [[ -x "$(which pdfimages)" ]]; then
  echo "pdfimages is required. Please install poppler-utils." >&2
  exit 1
fi
if ! [[ -x "$(which parallel)" ]]; then
  echo "GNU parallel is required. Please install parallel." >&2
  exit 1
fi
if ! [[ -x "$(which convert)" ]]; then
  echo "'convert' is required. Please install imagemagick." >&2
  exit 1
fi

# For each pdf...
find "$INPUT_DIR" -name "IMSLP*.pdf" | \
  # Convert to "IMSLPnnnnn-nnn.ppm" or ".pgm" images in $OUTPUT_DIR.
  perl -ne 'chomp; /(IMSLP[0-9]+)/; print qq(pdfimages "$_" "'"$OUTPUT_DIR"'"/$1\n)' | \
  # Run all emitted commands in parallel.
  parallel -v

# Convert extracted "pbm", "pgm", and "ppm" images to PNG.
(for file in "$OUTPUT_DIR"/*.p[bgp]m; do
  echo "convert '$file' '${file%.*}.png' && rm -v '$file'"
done) | parallel -v
================================================
FILE: moonlight/staves/BUILD
================================================
# Description:
# Staff detection and removal routines.
package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Umbrella library re-exporting the staff detector implementations.
py_library(
    name = "staves",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":hough",
        ":projection",
    ],
)

py_library(
    name = "base",
    srcs = ["base.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/util:memoize",
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "filter",
    srcs = ["filter.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/util:segments",
        # tensorflow dep
    ],
)

py_library(
    name = "hough",
    srcs = ["hough.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base",
        ":filter",
        ":staffline_distance",
        "//moonlight/util:memoize",
        "//moonlight/vision:hough",
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "projection",
    srcs = ["projection.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":base",
        "//moonlight/util:memoize",
        # numpy dep
        # scipy dep
        # tensorflow dep
    ],
)

py_test(
    name = "detectors_test",
    srcs = ["detectors_test.py"],
    data = ["//moonlight/testdata:images"],
    srcs_version = "PY2AND3",
    deps = [
        ":staffline_distance",
        ":staves",
        ":testing",
        # disable_tf2
        "//moonlight:image",
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "staff_processor",
    srcs = ["staff_processor.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "staff_processor_test",
    srcs = ["staff_processor_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":staff_processor",
        ":testing",
        # disable_tf2
        # absl/testing dep
        "//moonlight:engine",
        "//moonlight/glyphs:testing",
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/structure",
        # numpy dep
    ],
)

py_library(
    name = "staffline_extractor",
    srcs = ["staffline_extractor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":removal",
        ":staves",
        # enum34 dep
        "//moonlight:image",
        # six dep
        # tensorflow dep
    ],
)

py_test(
    name = "staffline_extractor_test",
    size = "small",
    timeout = "moderate",
    srcs = ["staffline_extractor_test.py"],
    data = ["//moonlight/testdata:images"],
    srcs_version = "PY2AND3",
    deps = [
        ":staffline_extractor",
        ":staves",
        # disable_tf2
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "staffline_distance",
    srcs = ["staffline_distance.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/util:run_length",
        "//moonlight/util:segments",
        # tensorflow dep
    ],
)

py_test(
    name = "staffline_distance_test",
    srcs = ["staffline_distance_test.py"],
    data = ["//moonlight/testdata:images"],
    srcs_version = "PY2AND3",
    deps = [
        ":staffline_distance",
        # disable_tf2
        "//moonlight:image",
        # tensorflow dep
    ],
)

py_library(
    name = "removal",
    srcs = ["removal.py"],
    deps = [
        "//moonlight/util:memoize",
        "//moonlight/util:segments",
        # tensorflow dep
    ],
)

py_test(
    name = "removal_test",
    srcs = ["removal_test.py"],
    data = ["//moonlight/testdata:images"],
    deps = [
        ":removal",
        ":staffline_distance",
        ":staves",
        # disable_tf2
        "//moonlight:image",
        # numpy dep
        # tensorflow dep
    ],
)

# Test-only helpers shared by the staff detector tests.
py_library(
    name = "testing",
    testonly = 1,
    srcs = ["testing.py"],
    srcs_version = "PY2AND3",
    deps = [":base"],
)
================================================
FILE: moonlight/staves/__init__.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staff detection.
Holds the staff detector classes that can be used as part of an OMR pipeline.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.staves import hough
from moonlight.staves import projection

# Alias the staff detectors to access them directly from the staves module.
# pylint: disable=invalid-name
FilteredHoughStaffDetector = hough.FilteredHoughStaffDetector
ProjectionStaffDetector = projection.ProjectionStaffDetector

# The default staff detector that should be used in production. (The
# projection detector does not detect all staves on real scores; see
# detectors_test.)
StaffDetector = FilteredHoughStaffDetector
================================================
FILE: moonlight/staves/base.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base class for all staff detectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import numpy as np
import tensorflow as tf
from moonlight.util import memoize
class BaseStaffDetector(object):
  """Base for a routine that returns staves in a music score.

  Attributes of concrete subclasses:
    staves: int32 tensor of shape (num_staves, num_points, 2) holding (x, y)
      points along each staff center line (validated by
      `staves_interpolated_y` below).
    staffline_distance: estimated distance between adjacent staff lines.
    staffline_thickness: estimated thickness of a staff line.
  """
  # Python 2-style abstract base class declaration.
  __metaclass__ = abc.ABCMeta

  def __init__(self, image=None):
    """Creates a staff detector for the given music score image.

    Args:
      image: The music score image. If none, sets self.image to a placeholder.
    """
    if image is None:
      # No image given: build the graph against a feedable placeholder of
      # unknown height/width.
      self.image = tf.placeholder(tf.uint8, shape=(None, None))
    else:
      self.image = tf.convert_to_tensor(image, tf.uint8)

  @property
  def data(self):
    """Returns the list of staff detection tensors to be computed.

    Returns:
      A list of Tensors. The order matches the positional arguments of
      `ComputedStaves.__init__`.
    """
    return [
        self.staves, self.staffline_distance, self.staffline_thickness,
        self.staves_interpolated_y
    ]

  @property
  @memoize.MemoizedFunction
  def staves_interpolated_y(self):
    """Interpolates the center line y coordinate for each staff.

    Calculates the staff center y for each x coordinate from 0 to `width - 1`.

    Returns:
      A tensor of shape (num_staves, width).
    """
    image_shape = tf.shape(self.image)

    def _get_staff_center_line_y(staff):
      """Interpolates the y position for the staff.

      For x values in the interval [0, image_shape[1]), calculate the y
      position. The y position past either end of the staff line is assumed to
      be the same as at the endpoint.

      Args:
        staff: The sequence of (x, y) coordinates for the staff center line.
          int32 tensor of shape (num_points, 2).

      Returns:
        The array of y position values.
      """
      staff = tf.convert_to_tensor(staff, dtype=tf.int32)
      # Runtime sanity checks on the staff points; these are graph-level
      # assertions evaluated when the tensors are computed.
      input_validation = [
          tf.Assert(
              tf.greater_equal(tf.shape(staff)[0], 2),
              [staff, tf.shape(staff)],
              name="at_least_2_points"),
          tf.Assert(
              tf.equal(tf.shape(staff)[1], 2), [staff, tf.shape(staff)],
              name="x_and_y"),
          tf.Assert(
              tf.greater_equal(staff[0, 0], 0), [image_shape, staff],
              name="staff_x_positive"),
          tf.Assert(
              tf.less(staff[-1, 0], image_shape[1]), [image_shape, staff],
              name="staff_x_ends_before_end_of_image"),
      ]
      # Validate the input before the main body.
      with tf.control_dependencies(input_validation):
        num_points = tf.shape(staff)[0]
        # The segments cover left of the staff, each consecutive pair of
        # points, and right of the staff.
        num_segments = num_points + 1

        def loop_body(i, ys_array):
          """Executes on each iteration of the TF while loop."""
          # Interpolate the y coordinates of the line between staff points
          # i - 1 and i (i >= 1). The y coordinates correspond to x in the
          # interval [staff[i - 1, 0], staff[i, 0]).
          x0 = staff[i - 1, 0]
          y0 = staff[i - 1, 1]
          x1 = staff[i, 0]
          y1 = staff[i, 1]
          # linspace endpoint is dropped ([:-1]) so consecutive segments do
          # not duplicate the shared point.
          segment_ys = (
              tf.cast(
                  tf.round(
                      tf.cast(y1 - y0, tf.float32) *
                      tf.linspace(0., 1., x1 - x0 + 1)[:-1]), tf.int32) + y0)
          # Update the loop variables. Increment i, and write the current
          # segment ys to the array.
          return i + 1, ys_array.write(i, segment_ys)

        # Run a while loop to generate line segments between consecutive staff
        # points. infer_shape=False because segments have differing lengths.
        all_ys_array = tf.TensorArray(
            tf.int32, infer_shape=False, size=num_segments)
        # The first segment covers [0, staff[0, 0]) (may be empty), repeating
        # the first point's y.
        all_ys_array = all_ys_array.write(0, tf.tile([staff[0, 1]],
                                                     [staff[0, 0]]))
        # Write the segments in the interval [1, num_segments - 2].
        unused_i, all_ys_array = tf.while_loop(
            lambda i, unused_ys: i < num_segments - 1, loop_body,
            [1, all_ys_array])
        # The last segment covers [staff[-1, 0], width) (may be empty),
        # repeating the last point's y.
        all_ys_array = all_ys_array.write(
            num_segments - 1,
            tf.tile([staff[-1, 1]], [image_shape[1] - staff[-1, 0]]))
        all_ys = all_ys_array.concat()
        # The concatenated segments must cover every column exactly once.
        output_validation = [
            tf.Assert(
                tf.equal(tf.shape(all_ys)[0], image_shape[1]),
                [tf.shape(all_ys), image_shape]),
        ]
        # Validate the output before returning. We need an actual op inside
        # the with statement (tf.identity).
        with tf.control_dependencies(output_validation):
          return tf.identity(all_ys)

    # The map_fn will fail if there are no staves. In that case, return an
    # empty array with the correct width.
    return tf.cond(
        tf.shape(self.staves)[0] > 0,
        lambda: tf.map_fn(_get_staff_center_line_y, self.staves),
        lambda: tf.zeros([0, image_shape[1]], tf.int32))

  def compute(self, session=None, feed_dict=None):
    """Runs staff detection.

    Args:
      session: The session to use instead of the default session.
      feed_dict: The feed dict for the TensorFlow graph.

    Returns:
      A `ComputedStaves` holding NumPy arrays for the staves.
    """
    if session is None:
      session = tf.get_default_session()
    return ComputedStaves(*session.run(self.data, feed_dict=feed_dict))
class ComputedStaves(BaseStaffDetector):
  """Holder for the concrete results of staff detection.

  Produced by `BaseStaffDetector.compute()`. Exposes the same interface as a
  detector, but every value is a NumPy array instead of a Tensor.
  """

  def __init__(self, staves, staffline_distance, staffline_thickness,
               staves_interpolated_y):
    super(ComputedStaves, self).__init__()
    # TODO(ringw): Add a way to ensure the inputs are array-like and not
    # Tensor objects.
    (self.staves, self.staffline_distance, self.staffline_thickness,
     self.staves_interpolated_y_arr) = (
         np.asarray(value)
         for value in (staves, staffline_distance, staffline_thickness,
                       staves_interpolated_y))

  @property
  def staves_interpolated_y(self):
    # Already materialized; just expose the stored array.
    return self.staves_interpolated_y_arr

  def compute(self, session=None, feed_dict=None):
    """Returns self; the staves were already computed.

    Args:
      session: TensorFlow session; ignored.
      feed_dict: TensorFlow feed dict; ignored.

    Returns:
      self.
    """
    return self
================================================
FILE: moonlight/staves/detectors_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the staff detectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import numpy as np
import tensorflow as tf
from moonlight import image as omr_image
from moonlight import staves
from moonlight.staves import staffline_distance
from moonlight.staves import testing
class StaffDetectorsTest(tf.test.TestCase):
  """End-to-end tests for the staff detector implementations."""

  def setUp(self):
    # The normal _MIN_STAFFLINE_DISTANCE_SCORE is too large for the small
    # images used in unit tests, so patch the module constant per test.
    self.old_min_staffline_distance_score = (
        staffline_distance._MIN_STAFFLINE_DISTANCE_SCORE)
    staffline_distance._MIN_STAFFLINE_DISTANCE_SCORE = 10

  def tearDown(self):
    # Restore the patched module constant.
    staffline_distance._MIN_STAFFLINE_DISTANCE_SCORE = (
        self.old_min_staffline_distance_score)

  def test_single_staff(self):
    blank_row = [255] * 50
    staff_row = [255] * 4 + [0] * 42 + [255] * 4
    # Create an image with 5 staff lines, with a slightly noisy staffline
    # thickness and distance.
    image = np.asarray([blank_row] * 25 + [staff_row] * 2 + [blank_row] * 8 +
                       [staff_row] * 3 + [blank_row] * 8 + [staff_row] * 3 +
                       [blank_row] * 9 + [staff_row] * 2 + [blank_row] * 8 +
                       [staff_row] * 2 + [blank_row] * 25, np.uint8)
    for detector in self.generate_staff_detectors(image):
      with self.test_session() as sess:
        staves_arr, staffline_distances, staffline_thickness = sess.run(
            (detector.staves, detector.staffline_distance,
             detector.staffline_thickness))
        expected_y = 25 + 2 + 8 + 3 + 8 + 1  # y coordinate of the center line
        self.assertEqual(
            staves_arr.shape[0], 1,
            'Expected single staff from detector %s. Got: %d' %
            (detector, staves_arr.shape[0]))
        self.assertAlmostEqual(
            np.mean(staves_arr[0, :, 1]),  # average y position
            expected_y,
            delta=2.0)
        self.assertAlmostEqual(staffline_distances[0], 11, delta=1.0)
        self.assertLessEqual(staffline_thickness, 3)

  def test_corpus_image(self):
    # Test only the default staff detector (because projection won't detect
    # all staves).
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    image_t = omr_image.decode_music_score_png(tf.read_file(filename))
    detector = staves.StaffDetector(image_t)
    with self.test_session() as sess:
      staves_arr, staffline_distances = sess.run(
          [detector.staves, detector.staffline_distance])
      self.assertAllClose(
          np.mean(staves_arr[:, :, 1], axis=1),  # average y position
          [413, 603, 848, 1040, 1286, 1476, 1724, 1915, 2162, 2354, 2604,
           2795],
          atol=5)
      self.assertAllEqual(staffline_distances, [16] * 12)

  def test_staves_interpolated_y(self):
    # Test staff center line interpolation.
    # The sequence of (x, y) points always starts at x = 0 and ends at
    # x = width - 1.
    staff = tf.constant(
        np.array([[[0, 10], [5, 5], [11, 0], [15, 10], [20, 20], [23, 49]]],
                 np.int32))
    with self.test_session():
      line_y = testing.FakeStaves(tf.zeros([50, 24]),
                                  staff).staves_interpolated_y[0].eval()
      # assertEquals is a deprecated alias of assertEqual.
      self.assertEqual(
          list(line_y), [
              10, 9, 8, 7, 6, 5, 4, 3, 3, 2, 1, 0, 2, 5, 8, 10, 12, 14, 16,
              18, 20, 30, 39, 49
          ])

  def test_staves_interpolated_y_empty(self):
    with self.test_session():
      self.assertAllEqual(
          testing.FakeStaves(tf.zeros([50, 25]), tf.zeros(
              [0, 2, 2], np.int32)).staves_interpolated_y.eval().shape,
          [0, 25])

  def test_staves_interpolated_y_staves_dont_extend_to_edge(self):
    staff = tf.constant(np.array([[[5, 10], [12, 8]]], np.int32))
    with self.test_session():
      # The y values should extend past the endpoints to the edge of the
      # image, and should be equal to the y value at the nearest endpoint.
      self.assertAllEqual(
          testing.FakeStaves(tf.zeros([50, 15]),
                             staff).staves_interpolated_y[0].eval(),
          [10, 10, 10, 10, 10, 10, 10, 9, 9, 9, 9, 8, 8, 8, 8])

  def generate_staff_detectors(self, image):
    # Yields one instance of each concrete detector for the given image.
    yield staves.ProjectionStaffDetector(image)
    yield staves.FilteredHoughStaffDetector(image)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/staves/filter.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staff center line filter.
Identifies candidates for the center (third line) of a staff in each column of
an image. Using the estimated staffline distance, there must be black pixels
in the expected positions of the five staff lines. Some stafflines will be
covered by a black glyph, so the black run will be thicker than the expected
staffline thickness. To account for this, at least three lines must have white
pixels both above and below them.
The filtered image is used for a Hough transform (`hough.py`) to robustly
identify staves.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import segments
# The minimum number of columns that must have a candidate staff center line
# in a given row, to detect the row as a staff center line.
MIN_STAFF_SLICES = 0.25
def staff_center_filter(image,
                        staffline_distance,
                        staffline_thickness,
                        threshold=127):
  """Filters the image for candidate staff center lines.

  A pixel is a candidate center line if dark pixels are found at all 5
  expected staff line positions, and at least 3 of those lines have light
  space both above and below them.

  Args:
    image: The 2D tensor image.
    staffline_distance: The estimated staffline distance. Scalar tensor.
    staffline_thickness: The estimated staffline thickness. Scalar tensor.
    threshold: Scalar tensor. Pixels below the threshold are black (possible
      stafflines).

  Returns:
    A boolean tensor of the same shape as image. The candidate center staff
    lines.
  """
  is_dark = image < threshold
  # Addition is not supported for unsigned ints, so count with int8 instead
  # of uint8.
  # How many of the 5 expected staff lines are dark at each pixel.
  num_dark_lines = tf.zeros_like(is_dark, tf.int8)
  # How many of the 5 expected staff lines have light space above and below
  # (i.e. look like a thin horizontal line rather than a filled glyph).
  num_spaced_lines = tf.zeros_like(is_dark, tf.int8)
  for line_index in range(-2, 3):
    expected_offset = line_index * staffline_distance
    # Allow each staffline to deviate slightly from its expected position: the
    # center line not at all, the second and fourth lines by 1 pixel, and the
    # outer lines by 2 pixels.
    line_found = tf.zeros_like(is_dark, tf.bool)
    space_found = tf.zeros_like(is_dark, tf.bool)
    max_deviation = abs(line_index)
    for deviation in range(-max_deviation, max_deviation + 1):
      line_offset = expected_offset + deviation
      above_offset = line_offset - 2 * staffline_thickness
      below_offset = line_offset + 2 * staffline_thickness
      line_found = tf.logical_or(line_found, _shift_y(is_dark, line_offset))
      white_above = tf.logical_not(_shift_y(is_dark, above_offset))
      white_below = tf.logical_not(_shift_y(is_dark, below_offset))
      space_found = tf.logical_or(space_found,
                                  tf.logical_and(white_above, white_below))
    num_dark_lines += tf.cast(line_found, tf.int8)
    num_spaced_lines += tf.cast(space_found, tf.int8)
  return tf.logical_and(
      tf.equal(num_dark_lines, 5),
      tf.greater_equal(num_spaced_lines, 3))
def _shift_y(image, y_offset):
  """Shift the image vertically.

  Args:
    image: The 2D tensor image.
    y_offset: The vertical offset for the image.

  Returns:
    The shifted image. Each pixel is shifted up or down by y_offset. Blank
    space is filled with zeros.
  """
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]

  def _unchanged():
    # The offset is at least as large as the image height; return as-is.
    return image

  def _shift_up():
    # Positive offset: drop rows from the top and pad zeros at the bottom.
    padding = tf.zeros([y_offset, image_width], dtype=image.dtype)
    return tf.concat([image[y_offset:], padding], axis=0)

  def _shift_down():
    # Negative offset: pad zeros at the top and drop rows from the bottom.
    padding = tf.zeros([-y_offset, image_width], dtype=image.dtype)
    return tf.concat([padding, image[:y_offset]], axis=0)

  return tf.cond(image_height <= tf.abs(y_offset), _unchanged,
                 lambda: tf.cond(y_offset >= 0, _shift_up, _shift_down))
def _get_staff_ys(is_staff, staffline_thickness):
  """Returns ys of detected staves: segments of `is_staff` that are roughly
  not much bigger than `staffline_thickness` (at most twice it)."""
  segment_centers, segment_heights = segments.true_segments_1d(is_staff)
  is_thin_enough = segment_heights <= staffline_thickness * 2
  return tf.boolean_mask(segment_centers, is_thin_enough)
================================================
FILE: moonlight/staves/hough.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Filtered hough staff detector.
Runs the staff center filter (see `filter.py`), and then uses the Hough
transform of the filtered image to detect nearly-horizontal lines (theta is
close to `pi / 2`).
The process is repeated for each unique detected staffline distance.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import tensorflow as tf
from moonlight.staves import base
from moonlight.staves import filter as staves_filter
from moonlight.staves import staffline_distance as distance
from moonlight.util import memoize
from moonlight.vision import hough
# The minimum number of columns that must have a candidate staff center line
# in a given row, to detect the row as a staff center line.
MIN_STAFF_SLICES = 0.25
DEFAULT_MAX_ABS_THETA = math.pi / 50
DEFAULT_NUM_THETA = 51
class FilteredHoughStaffDetector(base.BaseStaffDetector):
  """Filtered hough staff detector.

  Runs the staff center filter (see `filter.py`), and then uses the Hough
  transform of the filtered image to detect nearly-horizontal lines (theta is
  close to `pi / 2`). Detection is repeated once per estimated staffline
  distance.
  """

  def __init__(self,
               image=None,
               max_abs_theta=DEFAULT_MAX_ABS_THETA,
               num_theta=DEFAULT_NUM_THETA):
    """Filtered hough staff detector.

    Args:
      image: The image. If None, a placeholder will be created.
      max_abs_theta: The maximum deviation of the angle for the staff from the
        horizontal, in radians.
      num_theta: The number of thetas to be detected, between `pi/2 -
        max_abs_theta` and `pi/2 + max_abs_theta`.
    """
    super(FilteredHoughStaffDetector, self).__init__(image)
    staffline_distance, staffline_thickness = (
        distance.estimate_staffline_distance_and_thickness(self.image))
    self.estimated_staffline_distance = staffline_distance
    self.estimated_staffline_thickness = staffline_thickness
    self.max_abs_theta = float(max_abs_theta)
    self.num_theta = int(num_theta)

  @property
  def staves(self):
    staves, _ = self._data
    return staves

  @property
  def staffline_distance(self):
    _, staffline_distance = self._data
    return staffline_distance

  @property
  def staffline_thickness(self):
    return self.estimated_staffline_thickness

  @property
  @memoize.MemoizedFunction
  def _data(self):
    """Runs detection once per estimated staffline distance.

    Returns:
      A tuple (staves, staffline_distances): an int32 tensor of shape
      (N, 2, 2) with the staff center line endpoints, sorted by the starting
      y coordinate, and a length-N int32 tensor with the staffline distance
      used to detect each staff.
    """

    def detection_loop_body(i, staves, staffline_distances):
      """Per-staffline-distance staff detection loop.

      Args:
        i: The index of the current staffline distance to use.
        staves: The current staves tensor of shape (N, 2, 2).
        staffline_distances: The current staffline distance tensor. 1D with
          length N.

      Returns:
        i + 1.
        staves concatd with any newly detected staves.
        staffline_distance with the current staffline distance appended for
        each new staff.
      """
      current_staffline_distance = self.estimated_staffline_distance[i]
      current_staves = _SingleSizeFilteredHoughStaffDetector(
          self.image, current_staffline_distance,
          self.estimated_staffline_thickness, self.max_abs_theta,
          self.num_theta).staves
      staves = tf.concat([staves, current_staves], axis=0)
      staffline_distances = tf.concat(
          [
              staffline_distances,
              # Append one entry per *newly* detected staff. Tiling by the
              # total staff count (post-concat) would make
              # `staffline_distances` longer than `staves` whenever more than
              # one staffline distance is estimated.
              tf.tile([current_staffline_distance],
                      tf.shape(current_staves)[0:1]),
          ],
          axis=0)
      return i + 1, staves, staffline_distances

    num_staffline_distances = tf.shape(self.estimated_staffline_distance)[0]
    _, staves, staffline_distances = tf.while_loop(
        lambda i, _, __: tf.less(i, num_staffline_distances),
        detection_loop_body, [
            tf.constant(0),
            tf.zeros([0, 2, 2], tf.int32),
            tf.zeros([0], tf.int32)
        ],
        shape_invariants=[
            tf.TensorShape(()),
            tf.TensorShape([None, 2, 2]),
            tf.TensorShape([None])
        ],
        parallel_iterations=1)
    # Sort staves (and their distances) by the starting y coordinate (y0).
    order, = _argsort(staves[:, 0, 1])
    staves = tf.gather(staves, order)
    staffline_distances = tf.gather(staffline_distances, order)
    return staves, staffline_distances
class _SingleSizeFilteredHoughStaffDetector(object):
  """Filtered hough staff detector for a single staffline distance size.

  This is run in a loop by `FilteredHoughStaffDetector` in order to cover all
  of the detected staffline distances.

  Runs the staff center filter (see `filter.py`), and then uses the Hough
  transform of the filtered image to detect nearly-horizontal lines (theta is
  close to `pi / 2`).
  """

  def __init__(self, image, staffline_distance, staffline_thickness,
               max_abs_theta, num_theta):
    """Filtered hough staff detector.

    Args:
      image: The image. If None, a placeholder will be created.
      staffline_distance: The single staffline distance scalar to use.
      staffline_thickness: The staffline thickness.
      max_abs_theta: The maximum deviation of the angle for the staff from the
        horizontal, in radians.
      num_theta: The number of thetas to be detected, between `pi/2 -
        max_abs_theta` and `pi/2 + max_abs_theta`.
    """
    self.image = image
    self.estimated_staffline_distance = staffline_distance
    self.estimated_staffline_thickness = staffline_thickness
    self.max_abs_theta = float(max_abs_theta)
    self.num_theta = int(num_theta)

  # Memoize this to not re-compute "staves" as a different tensor each time
  # this property is referenced. TF's common subexpression elimination doesn't
  # seem to handle this case, maybe because we have too many ops. (The
  # decorator was previously missing even though this comment asked for it;
  # the same @property + MemoizedFunction pattern is used in base.py.)
  @property
  @memoize.MemoizedFunction
  def staves(self):
    """The staves detected for a single staffline distance.

    Returns:
      A staves tensor of shape (N, 2, 2).
    """
    height = tf.shape(self.image)[0]
    width = tf.shape(self.image)[1]
    staff_center = staves_filter.staff_center_filter(
        self.image, self.estimated_staffline_distance,
        self.estimated_staffline_thickness)
    all_thetas = tf.linspace(math.pi / 2 - self.max_abs_theta,
                             math.pi / 2 + self.max_abs_theta, self.num_theta)
    hough_bins = hough.hough_lines(staff_center, all_thetas)
    staff_rhos, staff_thetas = hough.hough_peaks(
        hough_bins,
        all_thetas,
        minval=MIN_STAFF_SLICES * tf.cast(width, tf.float32),
        invalidate_distance=self.estimated_staffline_distance * 4)
    num_staves = tf.shape(staff_rhos)[0]
    # Interpolate the start and end points for the staff center line.
    x0 = tf.zeros([num_staves], tf.int32)
    y0 = tf.cast(
        tf.cast(staff_rhos, tf.float32) / tf.sin(staff_thetas), tf.int32)
    x1 = tf.fill([num_staves], width - 1)
    y1 = tf.cast((tf.cast(staff_rhos, tf.float32) -
                  tf.cast(width - 1, tf.float32) * tf.cos(staff_thetas)) /
                 tf.sin(staff_thetas), tf.int32)
    # Cut out staves which have a start or end y outside of the image.
    is_valid = tf.logical_and(
        tf.logical_and(0 <= y0, y0 < height),
        tf.logical_and(0 <= y1, y1 < height))
    staves = tf.reshape(tf.stack([x0, y0, x1, y1], axis=1), [-1, 2, 2])
    return tf.boolean_mask(staves, is_valid)
# TODO(ringw): Add tf.argsort.
def _argsort(values):
  """Returns (as a single-element list) the indices that sort `values`."""
  sort_order = tf.py_func(np.argsort, [values], [tf.int64])
  return sort_order
================================================
FILE: moonlight/staves/projection.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A naive horizontal projection-based staff detector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import scipy.ndimage
import tensorflow as tf
from moonlight.staves import base
from moonlight.util import memoize
class ProjectionStaffDetector(base.BaseStaffDetector):
  """A naive staff detector that uses horizontal projections.

  Detects peaks in the number of black pixels in each row, which should
  correspond to staff lines.
  """
  # Result tensors, assigned in __init__. (staffline_thickness_tensor was
  # previously undeclared; declared here for consistency with its siblings.)
  staves_tensor = None
  staffline_distance_tensor = None
  staffline_thickness_tensor = None

  def __init__(self, image=None):
    super(ProjectionStaffDetector, self).__init__(image)
    # Count the dark pixels (value <= 127) in each row.
    projection = tf.reduce_sum(tf.cast(self.image <= 127, tf.int32), 1)
    width = tf.shape(self.image)[1]
    # A row is a staff line candidate if more than half its pixels are dark.
    min_num_dark_pixels = width // 2
    staff_lines = projection > min_num_dark_pixels
    # Group the candidate rows into staves in pure Python/NumPy.
    staves, staffline_distance, staffline_thickness = tf.py_func(
        _projection_to_staves, [staff_lines, width],
        [tf.int32, tf.int32, tf.int32])
    self.staves_tensor = staves
    self.staffline_distance_tensor = staffline_distance
    self.staffline_thickness_tensor = staffline_thickness

  @property
  @memoize.MemoizedFunction
  def staves(self):
    return self.staves_tensor

  @property
  @memoize.MemoizedFunction
  def staffline_distance(self):
    return self.staffline_distance_tensor

  @property
  @memoize.MemoizedFunction
  def staffline_thickness(self):
    return self.staffline_thickness_tensor
def _projection_to_staves(projection, width):
"""Pure python implementation of projection-based staff detection."""
labels, num_labels = scipy.ndimage.measurements.label(projection)
current_staff = []
staff_center_lines = []
staffline_distance = []
staffline_thicknesses = []
for line in range(1, num_labels + 1):
line_start = np.where(labels == line)[0].min()
line_end = np.where(labels == line)[0].max()
staffline_thickness = line_end - line_start + 1
line_center = np.int32(round((line_start + line_end) / 2.0))
current_staff.append(line_center)
if len(current_staff) > 5:
del current_staff[0]
if len(current_staff) == 5:
dists = np.array([
current_staff[1] - current_staff[0],
current_staff[2] - current_staff[1],
current_staff[3] - current_staff[2],
current_staff[4] - current_staff[3]
])
if np.max(dists) - np.min(dists) < 3:
staff_center = round(np.mean(current_staff))
staff_center_lines.append([[0, staff_center], [width - 1,
staff_center]])
staffline_distance.append(round(np.mean(dists)))
staffline_thicknesses.append(staffline_thickness)
staffline_thickness = (
np.median(staffline_thicknesses).astype(np.int32)
if staffline_thicknesses else np.int32(1))
return (np.array(staff_center_lines,
np.int32), np.array(staffline_distance,
np.int32), staffline_thickness)
================================================
FILE: moonlight/staves/removal.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staffline removal for glyph classification."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import memoize
from moonlight.util import segments
# The number of lines to remove above and below the staff center line. This
# removes the 5 staff lines, and 4 ledger lines (if present) above and below.
LINES_TO_REMOVE_ABOVE_AND_BELOW = 6
class StaffRemover(object):
  """Removes staff lines for glyph classification.

  Identifies and removes short vertical runs where we expect the staff lines.
  This means that the extracted staffline images for classification are more
  consistent, whether they are centered on the line or halfway between lines.
  """

  def __init__(self, staff_detector, threshold=127):
    # Detector supplying staves_interpolated_y, staffline_distance, and
    # staffline_thickness.
    self.staff_detector = staff_detector
    # Pixels below this value are considered black (candidate line pixels).
    self.threshold = threshold

  @property
  @memoize.MemoizedFunction
  def remove_staves(self):
    """Returns the page with staff lines removed.

    Returns:
      An image of the same size as `self.staff_detector.image`, with staff
      lines erased (set to white, 255).
    """
    image = tf.convert_to_tensor(self.staff_detector.image)
    height = tf.shape(image)[0]
    width = tf.shape(image)[1]
    # Max height of a run length that can be removed. Runs should have height
    # around staffline_thickness.
    max_runlength = self.staff_detector.staffline_thickness * 2
    # Calculate the expected y position of each staff line for each staff and
    # each column of the image. Broadcasting yields shape
    # (num_staves, 2 * LINES_TO_REMOVE_ABOVE_AND_BELOW + 1, width): every
    # staff line and nearby ledger line considered for removal.
    staff_center_ys = self.staff_detector.staves_interpolated_y
    all_staffline_center_ys = (
        staff_center_ys[:, None, :] +
        self.staff_detector.staffline_distance[:, None, None] *
        tf.range(-LINES_TO_REMOVE_ABOVE_AND_BELOW,
                 LINES_TO_REMOVE_ABOVE_AND_BELOW + 1)[None, :, None])
    # All row indices of the image, reused for every column.
    ys = tf.range(height)

    def _process_column(i):
      """Removes staves from a single column of the image.

      Args:
        i: The index of the column to remove.

      Returns:
        The single column of the image with staff lines erased.
      """
      column = image[:, i]
      # Identify runs in the column that correspond to staff lines and can be
      # erased.
      runs, run_lengths = segments.true_segments_1d(column < self.threshold)
      column_staffline_ys = all_staffline_center_ys[:, :, i]
      # The run center has to be within staffline_thickness of a staff line
      # (minimized over all staves and all line offsets).
      run_matches_staffline = tf.less_equal(
          tf.reduce_min(
              tf.abs(runs[:, None, None] - column_staffline_ys[None, :, :]),
              axis=[1, 2]), self.staff_detector.staffline_thickness)
      # Keep only runs that are thin enough AND aligned with a staff line.
      keep_run = tf.logical_and(run_lengths < max_runlength,
                                run_matches_staffline)
      # Give the mask a known rank-1 shape before boolean_mask.
      keep_run.set_shape([None])
      runs = tf.boolean_mask(runs, keep_run)
      run_lengths = tf.boolean_mask(run_lengths, keep_run)

      def do_process_column(runs, run_lengths):
        """Process the column if there are any runs matching staff lines.

        Args:
          runs: The center of each vertical run.
          run_lengths: The length of each vertical run.

        Returns:
          The column of the image with staff lines erased.
        """
        # Erase ys that belong to a run corresponding to a staff line: a y is
        # erased when it is within half a run length of the nearest run
        # center.
        y_run_pair_distance = tf.abs(ys[:, None] - runs[None, :])
        y_runs = tf.argmin(y_run_pair_distance, axis=1)
        y_run_distance = tf.reduce_min(y_run_pair_distance, axis=1)
        y_run_lengths = tf.gather(run_lengths, y_runs)
        erase_y = tf.less_equal(y_run_distance, tf.floordiv(y_run_lengths, 2))
        # Erased pixels become white (255).
        white_column = tf.fill(tf.shape(column), tf.constant(255, tf.uint8))
        return tf.where(erase_y, white_column, column)

      # Only process the column when there is at least one matching run;
      # otherwise return it unchanged.
      return tf.cond(
          tf.shape(runs)[0] > 0, lambda: do_process_column(runs, run_lengths),
          lambda: column)

    # Process each column independently; transpose the stacked columns back to
    # (height, width).
    return tf.transpose(
        tf.map_fn(
            _process_column,
            tf.range(width),
            name="staff_remover",
            dtype=tf.uint8))
================================================
FILE: moonlight/staves/removal_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for staff removal."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import numpy as np
import tensorflow as tf
from moonlight import image as omr_image
from moonlight import staves
from moonlight.staves import removal
from moonlight.staves import staffline_distance
class RemovalTest(tf.test.TestCase):
  """End-to-end test for StaffRemover on a corpus image."""

  def test_corpus_image(self):
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    image_t = omr_image.decode_music_score_png(tf.read_file(filename))
    remover = removal.StaffRemover(staves.StaffDetector(image_t))
    with self.test_session() as sess:
      removed, image = sess.run([remover.remove_staves, image_t])
      # Staff removal must actually change the image.
      self.assertFalse(np.allclose(removed, image))
      # If staff removal runs successfully, we should be unable to estimate
      # the staffline distance from the staves-removed image.
      # (A leftover debug print of est_staffline_distance was removed here.)
      est_staffline_distance, est_staffline_thickness = sess.run(
          staffline_distance.estimate_staffline_distance_and_thickness(
              removed))
      self.assertAllEqual([], est_staffline_distance)
      self.assertEqual(-1, est_staffline_thickness)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/staves/staff_processor.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adds staff location information to the Page.
The Page initially contains a single staff system with only glyphs, and this
adds the location of each staff.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six import moves
class StaffProcessor(object):
  """Fills in detected staff locations on a Page that starts with only glyphs."""

  def __init__(self, structure, staffline_extractor):
    self.staff_detector = structure.staff_detector
    self.staffline_extractor = staffline_extractor

  def apply(self, page):
    """Adds staff location information to the Page message."""
    assert len(page.system) == 1, ('Page must initially have a single staff '
                                   'system')
    assert len(page.system[0].staff) == len(self.staff_detector.staves), (
        'Glyphs page must have the same number of staves as the staff detector')
    detected = self.staff_detector.staves
    num_points = detected.shape[1]
    for staff_index, staff in enumerate(page.system[0].staff):
      staff.staffline_distance = self.staff_detector.staffline_distance[
          staff_index]
      for point_index in moves.xrange(num_points):
        # Skip interior points whose first coordinate matches both neighbors;
        # they add nothing to the polyline.
        if (0 < point_index < num_points - 1 and
            detected[staff_index, point_index - 1, 0] ==
            detected[staff_index, point_index, 0] ==
            detected[staff_index, point_index + 1, 0]):
          continue
        point = staff.center_line.add()
        point.x = detected[staff_index, point_index, 0]
        point.y = detected[staff_index, point_index, 1]
      # Scale the glyph x coordinates back for the original image.
      if self.staffline_extractor:
        # Height of an extracted image slice before scaling.
        unscaled_height = (
            staff.staffline_distance *
            self.staffline_extractor.staffline_distance_multiple)
        # Scale factor from the scaled staffline images to the original.
        scale = unscaled_height / self.staffline_extractor.target_height
        for glyph in staff.glyph:
          glyph.x = int(round(scale * glyph.x))
    return page
================================================
FILE: moonlight/staves/staff_processor_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the staff page processor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from absl.testing import absltest
import numpy as np
from moonlight import engine
from moonlight import structure as structure_module
from moonlight.glyphs import testing as glyphs_testing
from moonlight.protobuf import musicscore_pb2
from moonlight.staves import staff_processor
from moonlight.staves import testing as staves_testing
class StaffProcessorTest(absltest.TestCase):
  """Tests for StaffProcessor glyph x-coordinate rescaling."""

  def testGetPage_x_scale(self):
    """Checks that glyph x coordinates are scaled back to image space."""
    # NOTE: a previously-unused `dummy_stafflines` array was removed here.
    classifier = glyphs_testing.DummyGlyphClassifier(glyphs_testing.PREDICTIONS)
    image = np.random.randint(0, 255, (30, 20), dtype=np.uint8)
    staves = staves_testing.FakeStaves(
        image_t=image,
        staves_t=np.asarray([[[0, 10], [19, 10]], [[0, 20], [19, 20]]],
                            np.int32),
        staffline_distance_t=np.asarray([5, 20], np.int32),
        staffline_thickness_t=np.asarray(1, np.int32))
    structure = structure_module.create_structure(image,
                                                  lambda unused_image: staves)

    class DummyStafflineExtractor(object):
      """A placeholder for StafflineExtractor.

      It only contains the constants necessary to scale the x coordinates.
      """
      staffline_distance_multiple = 2
      target_height = 10

    omr = engine.OMREngine(lambda _: classifier)
    page = omr.process_image(
        # Feed in a dummy image. It doesn't matter because FakeStaves has
        # hard-coded staff values.
        np.random.randint(0, 255, (100, 100)),
        process_structure=False)
    page = staff_processor.StaffProcessor(structure,
                                          DummyStafflineExtractor()).apply(page)
    self.assertEqual(len(page.system[0].staff), 2)
    # The first staff has a staffline distance of 5.
    # The extracted staffline slices have an original height of
    # staffline_distance * staffline_distance_multiple (10), which equals
    # target_height here, so there is no scaling.
    self.assertEqual(
        musicscore_pb2.Staff(glyph=page.system[0].staff[0].glyph),
        glyphs_testing.GLYPHS_PAGE.system[0].staff[0])
    # Glyphs in the second staff have a scaled x coordinate.
    self.assertEqual(
        len(page.system[0].staff[1].glyph),
        len(glyphs_testing.GLYPHS_PAGE.system[0].staff[1].glyph))
    for glyph in glyphs_testing.GLYPHS_PAGE.system[0].staff[1].glyph:
      expected_glyph = copy.deepcopy(glyph)
      # The second staff has a staffline distance of 20. The extracted staffline
      # slice would be 4 times the size of the scaled staffline, so x
      # coordinates are scaled by 4. Also, the glyphs may be in a different
      # order.
      expected_glyph.x *= 4
      self.assertIn(expected_glyph, page.system[0].staff[1].glyph)
if __name__ == '__main__':
absltest.main()
================================================
FILE: moonlight/staves/staffline_distance.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements staffline distance estimation.
The staffline distance is the vertical distance between consecutive lines in a
staff, which is assumed to be uniform for a single staff on a scanned music
score. The staffline thickness is the vertical height of each staff line, which
is assumed to be uniform for the entire page.
Uses the algorithm described in [1], which creates a histogram of possible
staffline distance and thickness values for the entire image, based on the
vertical run-length encoding [2]. Each consecutive pair of black and white runs
contributes to the staffline distance histogram (because they may be the
staffline followed by an unobstructed space, or vice versa). We then take the
argmax of the histogram, and find candidate staff line runs. These runs must be
before or after another run, such that the sum of the run lengths is the
detected staffline distance. Then the black run is considered to be an actual
staff line, and its length contributes to the staffline thickness histogram.
Although we use a single staffline distance value for staffline thickness
detection, we may detect multiple distinct peaks in the histogram. We then run
staff detection using each distinct peak value, to detect smaller staves with an
unusual size, e.g. ossia parts [3].
[1] Cardoso, Jaime S., and Ana Rebelo. "Robust staffline thickness and distance
estimation in binary and gray-level music scores." 20th International
Conference on Pattern Recognition (ICPR). IEEE, 2010.
[2] https://en.wikipedia.org/wiki/Run-length_encoding
[3] https://en.wikipedia.org/wiki/Ossia
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import run_length
from moonlight.util import segments
# The size of the histograms. Normal values for the peak are around 20 for
# staffline distance, and 2-3 for staffline thickness.
_MAX_STAFFLINE_DISTANCE_THICKNESS_VALUE = 256
# The minimum number of votes for a staffline distance bin. We expect images to
# be a reasonable size (> 100x100), and want to ensure we exclude images that
# don't contain any staves.
_MIN_STAFFLINE_DISTANCE_SCORE = 10000
# The maximum allowed number of unique staffline distances. If more staffline
# distances are detected, return an empty list instead.
_MAX_ALLOWED_UNIQUE_STAFFLINE_DISTANCES = 3
_STAFFLINE_DISTANCE_INVALIDATE_DISTANCE = 1
_STAFFLINE_THICKNESS_INVALIDATE_DISTANCE = 1
_PEAK_CUTOFF = 0.5
def _single_peak(values, relative_cutoff, minval, invalidate_distance):
  """Takes a single peak if it is high enough compared to all other peaks.

  Args:
    values: 1D tensor of values to take the peaks on.
    relative_cutoff: The fraction of the highest peak which all other peaks
      should be below.
    minval: The peak should have at least this value.
    invalidate_distance: Exclude values that are up to invalidate_distance away
      from the peak.

  Returns:
    The index of the single peak in `values`, or -1 if there is not a single
    peak that satisfies `relative_cutoff`.
  """
  cutoff = tf.convert_to_tensor(relative_cutoff, tf.float32)
  # argmax is safe because the histogram is always non-empty.
  peak_index = tf.to_int32(tf.argmax(values))
  # Distance of every bin from the peak.
  offsets = tf.abs(tf.range(tf.shape(values)[0]) - peak_index)
  # Bins far enough from the peak to count as competing peaks.
  competitors = tf.boolean_mask(values,
                                tf.greater(offsets, invalidate_distance))
  peak_value = values[peak_index]
  is_tall_enough = tf.greater_equal(peak_value, minval)
  # Every competing bin must lie at or below peak_value * cutoff.
  dominates_rest = tf.reduce_all(
      tf.greater_equal(
          tf.to_float(peak_value) * cutoff, tf.to_float(competitors)))
  accept = tf.logical_and(is_tall_enough, dominates_rest)
  return tf.cond(accept, lambda: peak_index, lambda: -1)
def _estimate_staffline_distance(columns, lengths):
  """Estimates the staffline distances of a music score.

  Args:
    columns: 1D array. The column indices of each vertical run.
    lengths: 1D array. The length of each consecutive vertical run.

  Returns:
    A 1D tensor of possible staffline distances in the image.
  """
  with tf.name_scope('estimate_staffline_distance'):
    # Sum of each adjacent pair of run lengths. A (line, space) or
    # (space, line) pair of runs spans exactly one staffline distance.
    run_pair_lengths = lengths[:-1] + lengths[1:]
    # Only keep pairs whose two runs lie in the same image column.
    keep_pair = tf.equal(columns[:-1], columns[1:])
    staffline_distance_histogram = tf.bincount(
        tf.boolean_mask(run_pair_lengths, keep_pair),
        # minlength required to avoid errors on a fully white image.
        minlength=_MAX_STAFFLINE_DISTANCE_THICKNESS_VALUE,
        maxlength=_MAX_STAFFLINE_DISTANCE_THICKNESS_VALUE)
    # Candidate staffline distances: histogram peaks with enough votes.
    peaks = segments.peaks(
        staffline_distance_histogram,
        minval=_MIN_STAFFLINE_DISTANCE_SCORE,
        invalidate_distance=_STAFFLINE_DISTANCE_INVALIDATE_DISTANCE)

    def do_filter_peaks():
      """Process the peaks if they are non-empty.

      Returns:
        The filtered peaks. Peaks below the cutoff when compared to the highest
        peak are removed. If the peaks are invalid, then an empty list is
        returned.
      """
      histogram_size = tf.shape(staffline_distance_histogram)[0]
      peak_values = tf.to_float(tf.gather(staffline_distance_histogram, peaks))
      max_value = tf.reduce_max(peak_values)
      # Keep only peaks that are a substantial fraction of the highest peak.
      allowed_peaks = tf.greater_equal(peak_values,
                                       max_value * tf.constant(_PEAK_CUTOFF))
      # Check if there are too many detected staffline distances, and we should
      # return an empty list.
      allowed_peaks &= tf.less_equal(
          tf.reduce_sum(tf.to_int32(allowed_peaks)),
          _MAX_ALLOWED_UNIQUE_STAFFLINE_DISTANCES)
      # Check if any values sufficiently far away from the peaks are too high.
      # This means the peaks are not sharp enough and we should return an empty
      # list.
      far_from_peak = tf.greater(
          tf.reduce_min(
              tf.abs(tf.range(histogram_size)[None, :] - peaks[:, None]),
              axis=0), _STAFFLINE_DISTANCE_INVALIDATE_DISTANCE)
      allowed_peaks &= tf.less(
          tf.to_float(
              tf.reduce_max(
                  tf.boolean_mask(staffline_distance_histogram,
                                  far_from_peak))),
          max_value * tf.constant(_PEAK_CUTOFF))
      return tf.boolean_mask(peaks, allowed_peaks)

    # Guard the filtering branch so that gather/boolean_mask never run on an
    # empty peaks tensor; an empty `peaks` is returned unchanged.
    return tf.cond(
        tf.greater(tf.shape(peaks)[0], 0), do_filter_peaks,
        lambda: tf.identity(peaks))
def _estimate_staffline_thickness(columns, values, lengths, staffline_distance):
  """Estimates the staffline thickness of a music score.

  Args:
    columns: 1D array. The column indices of each consecutive vertical run.
    values: 1D array. The value (0 or 1) of each vertical run.
    lengths: 1D array. The length of each vertical run.
    staffline_distance: A 1D tensor of the possible staffline distances in the
      image. One of the distances may be chosen arbitrarily.

  Returns:
    A scalar tensor with the staffline thickness for the entire page, or -1 if
    it could not be estimated (staffline_distance is empty, or there are not
    enough runs to estimate the staffline thickness).
  """
  with tf.name_scope('estimate_staffline_thickness'):

    def do_estimate():
      """Compute the thickness if distance detection was successful."""
      run_pair_lengths = lengths[:-1] + lengths[1:]
      # Use the smallest staffline distance to estimate the staffline
      # thickness: keep only same-column pairs whose combined length equals
      # that distance.
      keep_pair = tf.logical_and(
          tf.equal(columns[:-1], columns[1:]),
          tf.equal(run_pair_lengths, staffline_distance[0]))
      start_values = tf.boolean_mask(values[:-1], keep_pair)
      start_lengths = tf.boolean_mask(lengths[:-1], keep_pair)
      end_lengths = tf.boolean_mask(lengths[1:], keep_pair)
      # The black (nonzero) run of each pair is the staff line; its length is
      # a staffline thickness sample.
      staffline_thickness_values = tf.where(
          tf.not_equal(start_values, 0), start_lengths, end_lengths)
      staffline_thickness_histogram = tf.bincount(
          staffline_thickness_values,
          minlength=_MAX_STAFFLINE_DISTANCE_THICKNESS_VALUE,
          maxlength=_MAX_STAFFLINE_DISTANCE_THICKNESS_VALUE)
      return _single_peak(
          staffline_thickness_histogram,
          _PEAK_CUTOFF,
          minval=1,
          invalidate_distance=_STAFFLINE_THICKNESS_INVALIDATE_DISTANCE)

    # Only estimate when at least one staffline distance was detected;
    # otherwise signal failure with -1.
    return tf.cond(
        tf.greater(tf.shape(staffline_distance)[0], 0), do_estimate,
        lambda: tf.constant(-1, tf.int32))
def estimate_staffline_distance_and_thickness(image, threshold=127):
  """Estimates the staffline distance and thickness of a music score.

  Args:
    image: A 2D tensor (HW) and type uint8.
    threshold: The global threshold for the image.

  Returns:
    The estimated vertical distance(s) from the center of one staffline to the
    next in the music score. 1D tensor containing all unique values of the
    estimated staffline distance for each staff.
    The estimated staffline thickness of the music score.

  Raises:
    TypeError: If `image` is an invalid type.
  """
  image = tf.convert_to_tensor(image, name='image', dtype=tf.uint8)
  threshold = tf.convert_to_tensor(threshold, name='threshold', dtype=tf.uint8)
  if image.dtype.base_dtype != tf.uint8:
    raise TypeError('Invalid dtype %s.' % image.dtype)
  # Binarize (dark pixels are ink) and run-length encode each column.
  is_ink = tf.less(image, threshold)
  columns, values, lengths = run_length.vertical_run_length_encoding(is_ink)
  distances = _estimate_staffline_distance(columns, lengths)
  thickness = _estimate_staffline_thickness(columns, values, lengths,
                                            distances)
  # thickness may be -1 even if distances is non-empty. Force distances to be
  # empty in that case, so callers can check either output to determine
  # whether there are staves.
  distances = tf.cond(
      tf.equal(thickness, -1), lambda: tf.zeros([0], tf.int32),
      lambda: tf.identity(distances))
  return distances, thickness
================================================
FILE: moonlight/staves/staffline_distance_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for staffline distance estimation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
from moonlight.image import decode_music_score_png
from moonlight.staves import staffline_distance
class StafflineDistanceTest(tf.test.TestCase):
  """Tests for staffline distance and thickness estimation."""

  def testCorpusImage(self):
    """A real corpus image yields the manually-verified estimates."""
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    # Use a context manager so the file handle is closed deterministically
    # (the original leaked the handle).
    with open(filename, 'rb') as f:
      image_contents = f.read()
    image_t = decode_music_score_png(tf.constant(image_contents))
    staffdist_t, staffthick_t = (
        staffline_distance.estimate_staffline_distance_and_thickness(image_t))
    with self.test_session() as sess:
      staffdist, staffthick = sess.run((staffdist_t, staffthick_t))
      # Manually determined values for the image.
      self.assertAllEqual(staffdist, [16])
      # assertEquals is a deprecated unittest alias; use assertEqual.
      self.assertEqual(staffthick, 2)

  def testZeros(self):
    # All white (0) shouldn't be picked up as a music score.
    image_t = tf.zeros((512, 512), dtype=tf.uint8)
    staffdist_t, staffthick_t = (
        staffline_distance.estimate_staffline_distance_and_thickness(image_t))
    with self.test_session() as sess:
      staffdist, staffthick = sess.run((staffdist_t, staffthick_t))
      self.assertAllEqual(staffdist, [])
      self.assertEqual(staffthick, -1)

  def testSpeckles(self):
    # Random speckles shouldn't be picked up as a music score.
    tf.set_random_seed(1234)
    image_t = tf.where(
        tf.random_uniform((512, 512)) < 0.1,
        tf.fill((512, 512), tf.constant(255, tf.uint8)),
        tf.fill((512, 512), tf.constant(0, tf.uint8)))
    staffdist_t, staffthick_t = (
        staffline_distance.estimate_staffline_distance_and_thickness(image_t))
    with self.test_session() as sess:
      staffdist, staffthick = sess.run((staffdist_t, staffthick_t))
      self.assertAllEqual(staffdist, [])
      self.assertEqual(staffthick, -1)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/staves/staffline_extractor.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracts horizontal slices from a staff for glyph classification."""
# TODO(ringw): Rename StafflineExtractor to PositionExtractor. Stafflines in
# this context should be renamed "extracted positions" to avoid confusion.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import enum
from moonlight import image as image_module
from moonlight import staves as staves_module
from moonlight.staves import removal
from six import moves
import tensorflow as tf
DEFAULT_TARGET_HEIGHT = 18
DEFAULT_NUM_SECTIONS = 19
DEFAULT_STAFFLINE_DISTANCE_MULTIPLE = 3
class Axes(enum.IntEnum):
STAFF = 0
POSITION = 1
Y = 2
X = 3
def get_staffline(y_position, extracted_staff_arr):
  """Selects one staffline strip from an extracted staff array.

  Args:
    y_position: The staffline position--the relative number of notes from the
      3rd line on the staff.
    extracted_staff_arr: An extracted staff NumPy array, e.g.
      `StafflineExtractor.extract_staves()[0].eval()` (`StafflineExtractor`
      returns multiple staves).

  Returns:
    The correct staffline from `extracted_staff_arr`, with shape
    `(target_height, image_width)`.

  Raises:
    ValueError: If the `y_position` is out of bounds in either direction.
  """
  num_stafflines = len(extracted_staff_arr)
  index = y_position_to_index(y_position, num_stafflines)
  return extracted_staff_arr[index]
def y_position_to_index(y_position, num_stafflines):
  """Maps a staffline y position to an index into the stafflines array.

  Position 0 is the center staffline (the middle array element); positive
  positions count upwards on the staff, towards smaller indices.

  Args:
    y_position: Relative position from the center staffline.
    num_stafflines: Total number of extracted stafflines.

  Returns:
    The index of the staffline within the array.

  Raises:
    ValueError: If the resulting index falls outside the array.
  """
  center = num_stafflines // 2
  index = center - y_position
  if 0 <= index < num_stafflines:
    return index
  raise ValueError('y_position %d too large for %d stafflines' %
                   (y_position, num_stafflines))
class StafflineExtractor(object):
  """Extracts horizontal slices from a staff for glyph classification.

  Glyphs must be centered on either a staff line or a staff space (halfway
  between staff lines). For classification, a window is extracted with height
  2*staffline_distance around a staffline or staff space. If num_sections is 9,
  extracts the five staff lines and the staff spaces between them.

  The slice is scaled proportionally to the staffline distance, making the
  output height equal to target_height, so that the glyph classifier is
  scale-invariant.

  This class is used in inference as part of a larger TF graph. See
  StafflinePatchExtractor for training.
  """

  def __init__(self,
               image,
               staves,
               target_height=DEFAULT_TARGET_HEIGHT,
               num_sections=DEFAULT_NUM_SECTIONS,
               staffline_distance_multiple=DEFAULT_STAFFLINE_DISTANCE_MULTIPLE):
    """Create the staffline extractor.

    Args:
      image: A uint8 tensor of shape (height, width). The background (usually
        white) must have a value of 0.
      staves: An instance of base.BaseStaffDetector.
      target_height: The height of the scaled output windows.
      num_sections: The number of stafflines to extract.
      staffline_distance_multiple: The height of the extracted staffline, in
        multiples of the staffline distance. For example, a notehead should fit
        in a staffline distance multiple of 1, because it starts and ends
        vertically on a staff line. However, other glyphs may need more space
        above and below to classify accurately.
    """
    # Normalize to float32 on a 0.0-1.0 scale for classification.
    self.float_image = tf.cast(image, tf.float32) / 255.
    self.staves = staves
    self.target_height = target_height
    self.num_sections = num_sections
    self.staffline_distance_multiple = staffline_distance_multiple
    # Calculate the maximum width needed. The smallest staffline distance
    # produces the widest resized strip, so all other staves are padded to
    # this width.
    min_staffline_distance = tf.reduce_min(staves.staffline_distance)
    self.target_width = self._get_resized_width(min_staffline_distance)

  def extract_staves(self):
    """Extracts stafflines from all staves in the image.

    Returns:
      A float32 Tensor of shape
      (num_staves, num_sections, target_height, slice_width). If the staffline
      distance is inconsistent between staves, smaller staves will be padded
      on the right with zeros.
    """
    # Only map if we have any staves, otherwise return an empty array with the
    # correct dimensionality.
    def do_extract_staves():
      """Actually performs staffline extraction if we have any staves.

      Returns:
        The stafflines tensor. See outer function doc.
      """
      staff_ys = self.staves.staves_interpolated_y

      def extract_staff(i):
        # Extract each staffline position for staff i. Positions range from
        # -(num_sections // 2) to num_sections // 2 inclusive, centered on the
        # middle staff line.
        def extract_staffline_by_index(j):
          return self._extract_staffline(staff_ys[i],
                                         self.staves.staffline_distance[i], j)

        return tf.map_fn(
            extract_staffline_by_index,
            tf.range(-(self.num_sections // 2), self.num_sections // 2 + 1),
            dtype=tf.float32)

      return tf.map_fn(
          extract_staff,
          tf.range(tf.shape(self.staves.staves)[0]),
          dtype=tf.float32)

    # Shape of the empty stafflines tensor, if no staves are present.
    empty_shape = (0, self.num_sections, self.target_height, 0)
    stafflines = tf.cond(
        tf.shape(self.staves.staves)[0] > 0,
        do_extract_staves,
        # Otherwise, return an empty stafflines array.
        lambda: tf.zeros(empty_shape, tf.float32))
    # We need target_height to be statically known for e.g. `util/patches.py`.
    stafflines.set_shape((None, self.num_sections, self.target_height, None))
    return stafflines

  def _extract_staffline(self, staff_y, staffline_distance, staffline_num):
    """Extracts a single staffline from a single staff.

    Args:
      staff_y: 1D tensor; the staff center-line y coordinate per image column.
      staffline_distance: Scalar staffline distance for this staff.
      staffline_num: Offset from the center staff line; the vertical offset
        applied is ceil(staffline_num * staffline_distance / 2) pixels.

    Returns:
      A 2D float32 tensor of shape (target_height, target_width).
    """
    # Use a float image on a 0.0-1.0 scale for classification.
    image_shape = tf.shape(self.float_image)
    height = image_shape[0]  # Can't unpack a tensor object.
    width = image_shape[1]
    # Calculate the height of the extracted staffline in the unscaled image.
    staff_window = self._get_staffline_window_size(staffline_distance)
    # Calculate the coordinates to extract for the window.
    # Note: tf.meshgrid uses xs before ys by default, but y is the 0th axis
    # for indexing.
    xs, ys = tf.meshgrid(
        tf.range(width),
        tf.range(staff_window) - (staff_window // 2))
    # ys are centered around 0. Add the staff_y, repeating along the
    # 0th axis.
    ys += tf.tile(staff_y[None, :], [staff_window, 1])
    # Add the offset for the staff line within the staff.
    # Round up in case the y position is not whole (in between staff lines with
    # an odd staffline distance). This puts the center of the staff space closer
    # to the center of the window.
    ys += tf.cast(
        tf.ceil(tf.truediv(staffline_num * staffline_distance, 2)), tf.int32)
    # Mark coordinates that fall outside the image bounds.
    invalid = tf.logical_not((0 <= ys) & (ys < height) & (0 <= xs)
                             & (xs < width))
    # Use a coordinate of (0, 0) for pixels outside of the original image.
    # We will then fill in those pixels with zeros.
    ys = tf.where(invalid, tf.zeros_like(ys), ys)
    xs = tf.where(invalid, tf.zeros_like(xs), xs)
    inds = tf.stack([ys, xs], axis=2)
    staffline_image = tf.gather_nd(self.float_image, inds)
    # Fill the pixels outside of the original image with zeros.
    staffline_image = tf.where(invalid, tf.zeros_like(staffline_image),
                               staffline_image)
    # Calculate the proportional width after scaling the height to
    # self.target_height.
    resized_width = self._get_resized_width(staffline_distance)
    # Use area resizing because we expect the output to be smaller.
    # Add extra axes, because we only have 1 image and 1 channel.
    staffline_image = tf.image.resize_area(
        staffline_image[None, :, :,
                        None], [self.target_height, resized_width])[0, :, :, 0]
    # Pad to make the width consistent with target_width.
    staffline_image = tf.pad(staffline_image,
                             [[0, 0], [0, self.target_width - resized_width]])
    return staffline_image

  def _get_resized_width(self, staffline_distance):
    """Proportional strip width after scaling its height to target_height."""
    image_width = tf.shape(self.float_image)[1]
    window_height = self._get_staffline_window_size(staffline_distance)
    return tf.cast(
        tf.round(tf.truediv(image_width * self.target_height, window_height)),
        tf.int32)

  def _get_staffline_window_size(self, staffline_distance):
    """Unscaled extraction window height: distance * multiple, rounded."""
    return tf.to_int32(
        tf.round(
            tf.to_float(staffline_distance) *
            tf.to_float(self.staffline_distance_multiple)))
class StafflinePatchExtractor(object):
  """Wraps the OMR TensorFlow graph and performs staff patch extraction.

  While inference uses StafflineExtractor/Convolutional1DGlyphClassifier to
  efficiently extract patches within the TF graph, StafflinePatchExtractor
  encapsulates the TF graph necessary for extraction. Therefore, it is to be
  used for training example extraction, where only staff detection and staffline
  extraction are run in TF.
  """

  def __init__(self,
               num_sections=DEFAULT_NUM_SECTIONS,
               patch_height=15,
               patch_width=12,
               run_options=None):
    """Builds the extraction graph.

    Args:
      num_sections: Number of staffline strips extracted per staff.
      patch_height: Height of the extracted patches, in pixels.
      patch_width: Width of the extracted patches, in pixels.
      run_options: Optional options forwarded to every `Session.run` call
        (presumably `tf.RunOptions` -- confirm with callers).
    """
    self.num_sections = num_sections
    self.patch_height = patch_height
    self.patch_width = patch_width
    self.run_options = run_options
    self.graph = tf.Graph()
    with self.graph.as_default():
      # Identifying information for the patch.
      self.filename = tf.placeholder(tf.string, name='filename')
      self.staff_index = tf.placeholder(tf.int64, name='staff_index')
      self.y_position = tf.placeholder(tf.int64, name='y_position')
      # Staff detection and removal run on the decoded image; stafflines are
      # extracted from the staves-removed image.
      image = image_module.decode_music_score_png(tf.read_file(self.filename))
      staff_detector = staves_module.StaffDetector(image)
      staff_remover = removal.StaffRemover(staff_detector)
      extractor = StafflineExtractor(
          staff_remover.remove_staves,
          staff_detector,
          num_sections=num_sections,
          target_height=patch_height)
      # Index into the staff strips array, where a y position of 0 is the center
      # element. Positive positions count up (towards higher notes, towards the
      # top of the image, and smaller indices into the array).
      position_index = num_sections // 2 - self.y_position
      self.all_stafflines = extractor.extract_staves()
      # The entire extracted horizontal strip of the image.
      self.staffline = self.all_stafflines[self.staff_index, position_index]
      # Determine the scale for converting image x coordinates to the scaled
      # staff strip from which the patch is extracted.
      extracted_staff_strip_height = tf.shape(self.all_stafflines)[2]
      unscaled_staff_strip_heights = tf.multiply(
          DEFAULT_STAFFLINE_DISTANCE_MULTIPLE,
          staff_detector.staffline_distance)
      self.all_staffline_scales = tf.divide(
          tf.to_float(extracted_staff_strip_height),
          tf.to_float(unscaled_staff_strip_heights))
      self.staffline_scale = self.all_staffline_scales[self.staff_index]

  def extract_staff_strip(self, filename, staff_index, y_position):
    """Extracts an entire horizontal strip from the image.

    Args:
      filename: The absolute filename of the image.
      staff_index: Index of the staff out of all staves on the page.
      y_position: Note y position on the staff, on which to extract the strip.
        The position starts out 0 for the staff center line, and grows more
        positive for higher notes.

    Returns:
      A tuple of:
        A wide strip of the image as a NumPy array.
        The scale factor from the original image scale to the normalized staff
          strip scale.
    """
    # Runs against the default session; callers must enter one for self.graph.
    return tf.get_default_session().run(
        [self.staffline, self.staffline_scale],
        feed_dict={
            self.filename: filename,
            self.staff_index: staff_index,
            self.y_position: y_position,
        },
        options=self.run_options)

  def extract_staff_patch(self, filename, staff_index, y_position, image_x):
    """Extracts a rectangular patch to be labeled.

    Args:
      filename: The absolute filename of the image.
      staff_index: Index of the staff out of all staves on the page.
      y_position: Note y position on the staff, on which to extract the strip.
      image_x: The coordinate of the patch center x, in image coordinates.

    Returns:
      The rectangular NumPy array for the patch.

    Raises:
      ValueError: If the x coordinate is too close to the left or right edge of
        the image to extract a full patch.
    """
    staffline, scale = self.extract_staff_strip(filename, staff_index,
                                                y_position)
    # Convert the image x coordinate to the scaled strip's coordinates.
    staffline_x = int(round(image_x * scale))
    patch_x_start = staffline_x + (-self.patch_width // 2)
    patch_x_stop = staffline_x + self.patch_width // 2
    if not (self.patch_width // 2 <= staffline_x <
            staffline.shape[1] - self.patch_width // 2):
      raise ValueError('image_x too close to bounds of image')
    return staffline[:, patch_x_start:patch_x_stop]

  def page_patch_iterator(self, filename):
    """Iterates over every patch on every staff.

    Args:
      filename: Path to a PNG file.

    Returns:
      A generator yielding pairs of patch id (with coordinates) and 2D
      `np.ndarray` with the patch contents.
    """
    # Run extraction once for the whole page; the generator below slides a
    # window over the fetched strips in NumPy.
    all_stafflines, scales = tf.get_default_session().run(
        [self.all_stafflines, self.all_staffline_scales],
        feed_dict={self.filename: filename},
        options=self.run_options)

    def generator():
      """The patch generator.

      Yields:
        Every extracted patch, as a logical id and patch ndarray.
      """
      for staff_index, staff in enumerate(all_stafflines):
        scale = scales[staff_index]
        for staffline_index, staffline in enumerate(staff):
          y_position = self.num_sections // 2 - staffline_index
          prev_image_x = None
          for staffline_x_start in moves.xrange(staffline.shape[1] -
                                                self.patch_width):
            image_x = int(
                round((staffline_x_start + self.patch_width // 2) / scale))
            # Skip windows that map to the same original image x as the
            # previous one, so each patch id is unique.
            if image_x != prev_image_x:
              patch_id = StafflinePatchExtractor.make_patch_id(
                  filename, staff_index, y_position, image_x)
              patch = staffline[:, staffline_x_start:staffline_x_start +
                                self.patch_width]
              yield patch_id, patch
            prev_image_x = image_x

    return generator()

  @staticmethod
  def make_patch_id(filename, staff_index, y_position, image_x):
    """Formats the short id uniquely identifying a patch in the training set."""
    short_filename, _ = os.path.splitext(os.path.basename(filename))
    # Format y_position as e.g. +2, +0, or -3.
    return '{},{},{:+d},{}'.format(short_filename, staff_index, y_position,
                                   image_x)
================================================
FILE: moonlight/staves/staffline_extractor_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for StafflineExtractor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import numpy as np
import tensorflow as tf
from moonlight import staves
from moonlight.staves import staffline_extractor
class StafflineExtractorTest(tf.test.TestCase):
  """Tests extraction of scaled staffline strips from a synthetic image."""

  def setUp(self):
    # Small image with a single staff.
    # Values are 0 (ink) or 255 (background) after the * 255 scaling below.
    # pyformat: disable
    self.single_staff_image = np.asarray(
        [[1, 1, 1, 1, 1, 1, 1],
         [1, 0, 0, 0, 0, 0, 1],
         [1, 1, 0, 1, 1, 1, 1],
         [1, 1, 1, 1, 0, 0, 1],
         [1, 0, 0, 0, 0, 0, 1],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 0, 0, 0, 0, 0, 1],
         [1, 0, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 0, 0, 0, 0, 0, 1],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 0, 0, 0, 0, 0, 1]], np.uint8) * 255

  def testExtractStaff(self):
    image_t = tf.constant(self.single_staff_image, name='image')
    detector = staves.ProjectionStaffDetector(image_t)
    # The staffline distance is 3, so use a target height of 6 to avoid scaling
    # the image.
    extractor = staffline_extractor.StafflineExtractor(
        image_t,
        detector,
        target_height=6,
        num_sections=9,
        staffline_distance_multiple=2)
    with self.test_session():
      stafflines = extractor.extract_staves().eval()
      # One staff, 9 strips per staff, each strip 6 pixels tall, 7 wide.
      assert stafflines.shape == (1, 9, 6, 7)
      # The top staff line is at a y-value of 2 because of rounding.
      assert np.array_equal(
          stafflines[0, 0],
          np.concatenate((np.zeros((2, 7)),
                          self.single_staff_image[:4] / 255.0)))
      # The staff space is centered in the window.
      assert np.array_equal(stafflines[0, 3],
                            self.single_staff_image[3:9] / 255.0)
      # Staffline height is 3 and extracted strips have 2 staff line distances
      # with a total height of 6, so the strip is not actually scaled.
      self.assertTrue(
          np.logical_or(np.isclose(stafflines, 1.), np.isclose(stafflines,
                                                               0.)).all())

  def testFloatMultiple(self):
    image_t = tf.constant(self.single_staff_image, name='image')
    detector = staves.ProjectionStaffDetector(image_t)
    extractor = staffline_extractor.StafflineExtractor(
        image_t,
        detector,
        target_height=6,
        num_sections=9,
        staffline_distance_multiple=1.5)
    with self.test_session():
      stafflines = extractor.extract_staves().eval()
      # Staff strip is scaled up because the input height is less. Some output
      # pixels have aliasing.
      self.assertAllEqual(stafflines.shape, (1, 9, 6, 10))
      # Interpolation produces intermediate values, not just pure 0s and 1s.
      self.assertFalse(
          np.logical_or(np.isclose(stafflines, 1.), np.isclose(stafflines,
                                                               0.)).all())
class StafflinePatchExtractorTest(tf.test.TestCase):
  """Checks patch extraction consistency on a real corpus image."""

  def testCompareIteratorWithSinglePatch(self):
    # Patches from the iterator should be exactly equal to the patch when the
    # coordinates from the id are given.
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    extractor = staffline_extractor.StafflinePatchExtractor()
    with self.test_session(graph=extractor.graph):
      single_patch = extractor.extract_staff_patch(filename, 3, 4, 800)
      # Unwrap the single patch with the matching id in the iterator.
      # The trailing comma asserts that exactly one id matches.
      patch_from_iterator, = (
          patch
          for patch_id, patch in extractor.page_patch_iterator(filename)
          if patch_id == 'IMSLP00747-000,3,+4,800'
      )
      self.assertAllClose(single_patch, patch_from_iterator)
      # Patch contains some content: both pure black and pure white pixels.
      self.assertAlmostEqual(single_patch.min(), 0, places=5)
      self.assertAlmostEqual(single_patch.max(), 1, places=5)
# Runs all test cases when invoked directly.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/staves/testing.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Staff detection test utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from moonlight.staves import base
class FakeStaves(base.BaseStaffDetector):
  """Fake staff detector holding an arbitrary staves tensor.

  Attributes:
    image: The image.
    staves_t: The staves given to the constructor. None may be given if the
      staves are never checked (only the staffline distance).
    staffline_distance_t: The estimated staffline distance. 1D tensor (values
      for each staff) or None.
    staffline_thickness_t: The estimated staffline thickness. Scalar tensor or
      None.
  """

  def __init__(self,
               image_t,
               staves_t,
               staffline_distance_t=None,
               staffline_thickness_t=None):
    # NOTE(review): BaseStaffDetector.__init__ is not invoked here; this fake
    # only stores the tensors it is given — confirm the base class has no
    # required initialization.
    self.image = image_t
    self.staves_t = staves_t
    self.staffline_distance_t = staffline_distance_t
    self.staffline_thickness_t = staffline_thickness_t

  @property
  def staves(self):
    # Returns the canned tensor instead of running any detection.
    return self.staves_t

  @property
  def staffline_distance(self):
    return self.staffline_distance_t

  @property
  def staffline_thickness(self):
    return self.staffline_thickness_t
================================================
FILE: moonlight/structure/BUILD
================================================
# Description:
# Music score structure detection above the staff level.
package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Aggregate library for the structure package (srcs is the package
# __init__.py, which wires the detectors together).
py_library(
    name = "structure",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":beams",
        ":components",
        ":verticals",
        "//moonlight/staves",
        "//moonlight/staves:base",
        "//moonlight/staves:removal",
        # tensorflow dep
    ],
)

py_test(
    name = "structure_test",
    srcs = ["structure_test.py"],
    data = ["//moonlight/testdata:images"],
    srcs_version = "PY2AND3",
    deps = [
        ":structure",
        # disable_tf2
        "//moonlight:image",
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "barlines",
    srcs = ["barlines.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
    ],
)

py_test(
    name = "barlines_test",
    srcs = ["barlines_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":barlines",
        ":beams",
        ":components",
        ":structure",
        ":verticals",
        # disable_tf2
        # absl/testing dep
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/staves:base",
        # numpy dep
    ],
)

py_library(
    name = "beams",
    srcs = ["beams.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":components",
        "//moonlight/vision:morphology",
        # numpy dep
        # tensorflow dep
    ],
)

py_library(
    name = "beam_processor",
    srcs = ["beam_processor.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":components",
        "//moonlight/glyphs:glyph_types",
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
    ],
)

py_library(
    name = "components",
    srcs = ["components.py"],
    srcs_version = "PY2AND3",
    deps = [
        # enum34 dep
        # tensorflow dep
        # tensorflow.contrib.image py dep
    ],
)

py_test(
    name = "components_test",
    srcs = ["components_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":components",
        # disable_tf2
        # tensorflow dep
    ],
)

py_library(
    name = "section_barlines",
    srcs = ["section_barlines.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":barlines",
        ":components",
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
    ],
)

py_library(
    name = "stems",
    srcs = ["stems.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/glyphs:glyph_types",
        "//moonlight/protobuf:protobuf_py_pb2",
        # numpy dep
    ],
)

py_test(
    name = "stems_test",
    srcs = ["stems_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":beams",
        ":components",
        ":stems",
        ":structure",
        ":verticals",
        # disable_tf2
        # absl/testing dep
        "//moonlight/protobuf:protobuf_py_pb2",
        "//moonlight/staves:base",
        # numpy dep
    ],
)

py_library(
    name = "verticals",
    srcs = ["verticals.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/util:functional_ops",
        "//moonlight/util:memoize",
        "//moonlight/util:segments",
        "//moonlight/vision:images",
        "//moonlight/vision:morphology",
        # numpy dep
        # tensorflow dep
    ],
)

py_test(
    name = "verticals_test",
    srcs = ["verticals_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":verticals",
        # disable_tf2
        "//moonlight/staves:testing",
        # numpy dep
        # tensorflow dep
    ],
)
================================================
FILE: moonlight/structure/__init__.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Holder for the page structure detectors.
`create_structure()` constructs the staff and verticals detectors with the
given callables. `Structure.compute()` is run to compute all structure in a
single TensorFlow graph, to increase parallelism.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight import staves
from moonlight.staves import base as staves_base
from moonlight.staves import removal
from moonlight.structure import beams as beams_module
from moonlight.structure import components as components_module
from moonlight.structure import verticals as verticals_module
def create_structure(image,
                     staff_detector=staves.StaffDetector,
                     beams=beams_module.Beams,
                     verticals=verticals_module.ColumnBasedVerticals,
                     components=components_module.from_staff_remover):
  """Constructs a Structure instance.

  Constructs a staff detector and verticals with the given callables.

  Args:
    image: The image tensor.
    staff_detector: A callable that accepts the image and returns a
      StaffDetector.
    beams: A callable that accepts a StaffRemover and returns a Beams.
    verticals: A callable that accepts the staff detector and returns a
      verticals impl (e.g. ColumnBasedVerticals).
    components: A callable that accepts a StaffRemover and returns a
      ConnectedComponents.

  Returns:
    The Structure instance.
  """
  # Each detector is built under its own name scope so the graph is readable.
  with tf.name_scope('staff_detector'):
    staff_detector = staff_detector(image)
  # Beams and components work on the staves-removed image.
  with tf.name_scope('staff_remover'):
    staff_remover = removal.StaffRemover(staff_detector)
  with tf.name_scope('beams'):
    beams = beams(staff_remover)
  with tf.name_scope('verticals'):
    verticals = verticals(staff_detector)
  with tf.name_scope('components'):
    components = components(staff_remover)
  structure = Structure(
      staff_detector,
      beams,
      verticals,
      components,
      image=image,
      staff_remover=staff_remover)
  return structure
class Structure(object):
  """Holds page structure detectors.

  Attributes:
    image: The image tensor, or None.
    staff_detector: The staff detector (TF-backed, or ComputedStaves).
    beams: The beam detector (TF-backed, or ComputedBeams).
    verticals: The verticals detector (TF-backed, or ComputedVerticals).
    connected_components: The connected components (TF-backed, or
      ComputedComponents).
    staff_remover: The StaffRemover, or None.
  """

  def __init__(self,
               staff_detector,
               beams,
               verticals,
               connected_components,
               image=None,
               staff_remover=None):
    self.image = image
    self.staff_detector = staff_detector
    self.beams = beams
    self.verticals = verticals
    self.connected_components = connected_components
    self.staff_remover = staff_remover

  def compute(self, session=None, image=None):
    """Computes the structure.

    If the staves are already `ComputedStaves` and the verticals are already
    `ComputedVerticals`, returns `self`. Otherwise, runs staff detection and/or
    verticals detection in the TensorFlow `session`.

    Args:
      session: The TensorFlow session to use instead of the default session.
      image: If non-None, fed as the value of `self.staff_detector.image`.

    Returns:
      A computed `Structure` object. `staff_detector` and `verticals` hold NumPy
      arrays with the result of the TensorFlow graph.
    """
    # For each detector that is already computed, fetch nothing from the
    # session; its existing data is reused below.
    if isinstance(self.staff_detector, staves_base.ComputedStaves):
      staff_detector_data = []
    else:
      staff_detector_data = self.staff_detector.data
    if isinstance(self.beams, beams_module.ComputedBeams):
      beams_data = []
    else:
      beams_data = self.beams.data
    if isinstance(self.verticals, verticals_module.ComputedVerticals):
      verticals_data = []
    else:
      verticals_data = self.verticals.data
    if isinstance(self.connected_components,
                  components_module.ComputedComponents):
      components_data = []
    else:
      components_data = self.connected_components.data
    if not (staff_detector_data or beams_data or verticals_data or
            components_data):
      # Everything is already computed.
      return self
    if not session:
      session = tf.get_default_session()
    if image is not None:
      feed_dict = {self.staff_detector.image: image}
    else:
      feed_dict = {}
    staff_detector_data, beams_data, verticals_data, components_data = (
        session.run(
            [staff_detector_data, beams_data, verticals_data, components_data],
            feed_dict=feed_dict))
    # An empty result list means the detector was already computed; fall back
    # to its existing data.
    staff_detector_data = staff_detector_data or self.staff_detector.data
    staff_detector = staves_base.ComputedStaves(*staff_detector_data)
    beams_data = beams_data or self.beams.data
    beams = beams_module.ComputedBeams(*beams_data)
    verticals_data = verticals_data or self.verticals.data
    verticals = verticals_module.ComputedVerticals(*verticals_data)
    # Fall back to existing data like the other detectors, and wrap the result
    # in ComputedComponents (not the TF-backed ConnectedComponents) so that
    # is_computed() holds for the returned Structure.
    components_data = components_data or self.connected_components.data
    connected_components = components_module.ComputedComponents(
        *components_data)
    return Structure(
        staff_detector, beams, verticals, connected_components, image=image)

  def is_computed(self):
    """Whether all four detectors already hold computed (NumPy) results."""
    return (isinstance(self.staff_detector, staves_base.ComputedStaves) and
            isinstance(self.beams, beams_module.ComputedBeams) and
            isinstance(self.verticals, verticals_module.ComputedVerticals) and
            isinstance(self.connected_components,
                       components_module.ComputedComponents))

  @property
  def data(self):
    """The fetchable data for all detectors, in `compute()`'s run order."""
    return [
        self.staff_detector.data, self.beams.data, self.verticals.data,
        self.connected_components.data
    ]
================================================
FILE: moonlight/structure/barlines.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Splits the single StaffSystem into multiple StaffSystems with bars."""
# TODO(ringw): Detect double barlines (with the expected distance between
# them) as one DOUBLE_BAR.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from moonlight.protobuf import musicscore_pb2
from six import moves
class Barlines(object):
  """Staff system and barline detector.

  Selects the vertical lines from the computed structure that are valid
  barlines, and uses them to split the page's single staff system into
  multiple systems joined by those barlines.
  """

  def __init__(self, structure, close_barline_threshold=None):
    """Detects the candidate barlines from the computed structure.

    Args:
      structure: The computed `Structure`, providing verticals and staves.
      close_barline_threshold: Barlines closer than this (in x) to an already
        selected barline are treated as duplicates. Defaults to 4 times the
        median staffline distance.
    """
    barline_valid, self.barline_staff_start, self.barline_staff_end = (
        assign_barlines_to_staves(
            # Use the mean x of the line's two endpoints as the barline x.
            barline_x=structure.verticals.lines[:, :,
                                                0].mean(axis=1).astype(int),
            barline_y0=structure.verticals.lines[:, 0, 1],
            barline_y1=structure.verticals.lines[:, 1, 1],
            staff_detector=structure.staff_detector))
    self.barlines = structure.verticals.lines[barline_valid]
    self.close_barline_threshold = (
        close_barline_threshold or
        np.median(structure.staff_detector.staffline_distance) * 4)

  def apply(self, page):
    """Splits the staves in the page into systems with barlines.

    Args:
      page: A Page message holding exactly one StaffSystem.

    Returns:
      A new Page whose systems group the staves joined by barlines, with the
      selected barlines assigned to each system.
    """
    assert len(page.system) == 1
    # Maps each staff index to the (start, end) staff span of its system.
    systems_map = dict(
        (i, (i, i)) for i in moves.xrange(len(page.system[0].staff)))
    for start, end in zip(self.barline_staff_start, self.barline_staff_end):
      # Grow the span to cover any systems that the joined staves already
      # belong to, then record the merged span for every staff in it.
      for staff in moves.xrange(start, end + 1):
        start = min(start, systems_map[staff][0])
        end = max(end, systems_map[staff][1])
      for staff in moves.xrange(start, end + 1):
        systems_map[staff] = (start, end)
    system_inds = sorted(set(systems_map.values()), key=lambda x: x[0])
    staves = page.system[0].staff
    systems = [
        musicscore_pb2.StaffSystem(staff=staves[start:end + 1])
        for (start, end) in system_inds
    ]
    self._assign_barlines(systems)
    return musicscore_pb2.Page(system=systems)

  def _assign_barlines(self, systems):
    """Assigns each barline to a system.

    Args:
      systems: The list of StaffSystem messages.
    """
    system_start = 0
    for system in systems:
      system_end = system_start + len(system.staff) - 1
      selected_barlines = set()
      blacklist_x = self._get_blacklist_x(system)
      for i in moves.xrange(len(self.barlines)):
        barline_x = self.barlines[i, 0, 0]
        start = self.barline_staff_start[i]
        end = self.barline_staff_end[i]
        if (not blacklist_x[barline_x] and
            system_start <= start <= end <= system_end):
          # Get the selected barlines which are close enough to the current
          # barline that they are probably a duplicate.
          close_barlines = [
              other_barline for other_barline in selected_barlines
              if abs(self.barlines[other_barline, 0, 0] -
                     barline_x) < self.close_barline_threshold
          ]

          def get_span(barline):
            # Number of staves (minus one) that the barline spans.
            return (self.barline_staff_end[barline] -
                    self.barline_staff_start[barline])

          # Assumes all barlines span the entire staff system.
          # Don't add a barline if we've already seen a duplicate unless it
          # spans more staves than the currently selected one.
          # TODO(ringw): This works for piano scores, but not multi-part
          # scores, which have one barline spanning the entire staff system at
          # the beginning and then one barline per staff for the following
          # measures. Make this more robust.
          if (all(end - start >= get_span(other_barline)
                  for other_barline in selected_barlines) and
              all(end - start > get_span(other_barline)
                  for other_barline in close_barlines)):
            selected_barlines.difference_update(close_barlines)
            selected_barlines.add(i)
      barline_xs = sorted(
          self.barlines[barline, 0, 0] for barline in selected_barlines)
      system.bar.extend(
          musicscore_pb2.StaffSystem.Bar(
              x=x, type=musicscore_pb2.StaffSystem.Bar.STANDARD_BAR)
          for x in barline_xs)
      system_start = system_end + 1

  def _get_blacklist_x(self, system):
    """Computes the x coordinates that are blacklisted for barlines.

    Barlines cannot be too close to a detected stem, because stems at a certain
    vertical position could be confused with barlines spanning a single staff.

    Args:
      system: The StaffSystem message.

    Returns:
      A boolean NumPy array. 1D and long enough to contain all of the barlines
      on the x axis. True for x coordinates where barlines are disallowed.
    """
    staffline_distance = np.median(
        [staff.staffline_distance for staff in system.staff]).astype(int)
    # Width needed to contain all of the barlines.
    barlines_width = (0 if self.barlines.size == 0 else
                      np.max(self.barlines[:, :, 0]) + 1)
    # Use the builtin bool; the np.bool alias was removed in NumPy 1.24.
    blacklist_x = np.zeros(barlines_width, bool)
    for staff in system.staff:
      for glyph in staff.glyph:
        if glyph.HasField('stem'):
          stem = glyph.stem
          # Block out one staffline distance on either side of the stem.
          blacklist_start = max(
              0,
              min(stem.start.x, stem.end.x) - staffline_distance)
          blacklist_end = min(
              barlines_width,
              max(stem.start.x, stem.end.x) + staffline_distance)
          blacklist_x[blacklist_start:blacklist_end] = True
    return blacklist_x
def assign_barlines_to_staves(barline_x, barline_y0, barline_y1,
                              staff_detector):
  """Selects valid barlines and their start and end staves.

  A vertical line is accepted as a barline when its top endpoint is close to
  the expected top of some staff (two staffline distances above the staff
  center line) and its bottom endpoint is close to the expected bottom of some
  staff. "Close" means within one staffline distance of that staff.

  Args:
    barline_x: 1D array of length N. The barline x coordinates.
    barline_y0: 1D array of length N. The barline top y coordinates.
    barline_y1: 1D array of length N. The barline bottom y coordinates.
    staff_detector: A BaseStaffDetector, for reading the staffline distance.

  Returns:
    A tuple of:
    barline_valid: Boolean array of length N. Whether each input line was
      accepted as a valid barline.
    barline_staff_start: Integer array of length `K = barline_valid.sum()`.
      The staff index at the top of each valid barline.
    barline_staff_end: Integer array of length K. The staff index at the
      bottom of each valid barline. Each entry is >= barline_staff_start.
  """
  # Maximum allowed distance (per staff) from a line endpoint to the expected
  # barline start or end for that staff.
  distance_tolerance = staff_detector.staffline_distance
  half_height = 2 * staff_detector.staffline_distance[:, None]
  expected_top = staff_detector.staves_interpolated_y - half_height
  expected_bottom = staff_detector.staves_interpolated_y + half_height
  # Distance from every line's top endpoint to every staff's expected top,
  # sampled at the line's x coordinate. Shape (num_staves, N).
  top_distance = np.abs(barline_y0[None, :] - expected_top[:, barline_x])
  staff_start = np.argmin(top_distance, axis=0)
  valid = np.less_equal(
      np.min(top_distance, axis=0), distance_tolerance[staff_start])
  # Same check for the bottom endpoint against the closest staff bottom.
  bottom_distance = np.abs(barline_y1 - expected_bottom[:, barline_x])
  staff_end = np.argmin(bottom_distance, axis=0)
  valid &= np.less_equal(
      np.min(bottom_distance, axis=0), distance_tolerance[staff_end])
  return valid, staff_start[valid], staff_end[valid]
================================================
FILE: moonlight/structure/barlines_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for stem detection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from moonlight import structure
from moonlight.protobuf import musicscore_pb2
from moonlight.staves import base as staves_base
from moonlight.structure import barlines as barlines_module
from moonlight.structure import beams
from moonlight.structure import components
from moonlight.structure import verticals
Point = musicscore_pb2.Point # pylint: disable=invalid-name
class BarlinesTest(absltest.TestCase):
  """End-to-end test of barline selection and staff system splitting."""

  def testDummy(self):
    # Build a fake computed structure with 4 staves (centered at y = 50, 150,
    # 250, 350) and several candidate vertical lines; the inline comments
    # describe the expected fate of each line.
    struct = structure.Structure(
        staff_detector=staves_base.ComputedStaves(
            staves=[[[10, 50], [90, 50]], [[11, 150], [91, 150]],
                    [[10, 250], [90, 250]], [[10, 350], [90, 350]]],
            staffline_distance=[12] * 4,
            staffline_thickness=2,
            staves_interpolated_y=[[50] * 100, [150] * 100, [250] * 100,
                                   [350] * 100]),
        beams=beams.ComputedBeams(np.zeros((0, 2, 2))),
        connected_components=components.ComputedComponents(np.zeros((0, 5))),
        verticals=verticals.ComputedVerticals(lines=[
            # Joins the first 2 staves.
            [[10, 50 - 12 * 2], [10, 150 + 12 * 2]],
            # Another barline, too close to the first one.
            [[12, 50 - 12 * 2], [12, 150 + 12 * 2]],
            # This barline is far enough, because the second barline was
            # skipped.
            [[13, 50 - 12 * 2], [13, 150 + 12 * 2]],
            # Single staff barlines are skipped.
            [[30, 50 - 12 * 2], [30, 50 + 12 * 2]],
            [[31, 150 - 12 * 2], [31, 150 + 12 * 2]],
            # Too close to a stem.
            [[70, 50 - 12 * 2], [70, 50 + 12 * 2]],
            # Too short.
            [[90, 50 - 12 * 2], [90, 50 + 12 * 2]],
            # Another barline which is kept.
            [[90, 50 - 12 * 2], [90, 150 + 12 * 2]],
            # Staff 1 has no barlines.
            # Staff 2 has 2 barlines.
            [[11, 350 - 12 * 2], [11, 350 + 12 * 2]],
            [[90, 350 - 12 * 2], [90, 350 + 12 * 2]],
        ]))
    barlines = barlines_module.Barlines(struct, close_barline_threshold=3)
    # Create a Page with Glyphs.
    input_page = musicscore_pb2.Page(system=[
        musicscore_pb2.StaffSystem(staff=[
            musicscore_pb2.Staff(
                staffline_distance=12,
                center_line=[
                    musicscore_pb2.Point(x=10, y=50),
                    musicscore_pb2.Point(x=90, y=50)
                ],
                glyph=[
                    # Stem is close to the last vertical on the first staff, so
                    # a barline will not be detected there.
                    musicscore_pb2.Glyph(
                        type=musicscore_pb2.Glyph.NOTEHEAD_FILLED,
                        x=60,
                        y_position=2,
                        stem=musicscore_pb2.LineSegment(
                            start=musicscore_pb2.Point(x=72, y=40),
                            end=musicscore_pb2.Point(x=72, y=80))),
                ]),
            musicscore_pb2.Staff(
                staffline_distance=12,
                center_line=[
                    musicscore_pb2.Point(x=10, y=150),
                    musicscore_pb2.Point(x=90, y=150)
                ]),
            musicscore_pb2.Staff(
                staffline_distance=12,
                center_line=[
                    musicscore_pb2.Point(x=10, y=250),
                    musicscore_pb2.Point(x=90, y=250)
                ]),
            musicscore_pb2.Staff(
                staffline_distance=12,
                center_line=[
                    musicscore_pb2.Point(x=10, y=350),
                    musicscore_pb2.Point(x=90, y=350)
                ]),
        ])
    ])
    page = barlines.apply(input_page)
    # Staves 0-1 are joined into one system; staves 2 and 3 stand alone.
    self.assertEqual(3, len(page.system))
    self.assertEqual(2, len(page.system[0].staff))
    self.assertItemsEqual([10, 13, 90], (bar.x for bar in page.system[0].bar))
    self.assertEqual(1, len(page.system[1].staff))
    self.assertEqual(0, len(page.system[1].bar))
    self.assertEqual(1, len(page.system[2].staff))
    self.assertEqual(2, len(page.system[2].bar))
    self.assertItemsEqual([11, 90], (bar.x for bar in page.system[2].bar))
# Runs all test cases when invoked directly.
if __name__ == "__main__":
  absltest.main()
================================================
FILE: moonlight/structure/beam_processor.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adds Beams to notes with intersecting Stems.
First, detects beams that have enough area (black pixel count) proportionate to
their width to count as multiple beams. The beam coordinates are just repeated
in this case for each detected beam, because we don't know specifically where
each individual beam is.
Next, for each stem already attached to a note, we assign any intersecting beam
candidates to the note.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from moonlight.glyphs import glyph_types
from moonlight.protobuf import musicscore_pb2
from moonlight.structure import components
# Shorthand for the connected component column indices (X0, Y0, X1, Y1, SIZE).
COLUMNS = components.ConnectedComponentsColumns
class BeamProcessor(object):
  """Assigns detected beam candidates to notes through their stems."""

  def __init__(self, structure):
    # A candidate component may actually contain several merged beams; repeat
    # such entries based on their pixel area before matching against stems.
    self.beams = _maybe_duplicate_beams(
        structure.beams.beams,
        np.median(structure.staff_detector.staffline_distance))

  def apply(self, page):
    """Adds beams that intersect with note stems to the page.

    Beams should intersect with two or more stems. Beams are currently
    implemented as a bounding box, so we just see whether that box intersects
    with each stem.

    Args:
      page: A Page message.

    Returns:
      The same page, with `beam`s added to the `Glyph`s.
    """
    for system in page.system:
      for staff in system.staff:
        # Extend the beams by the staffline distance on either side. Beams may
        # end immediately at a stem, so give an extra allowance for that stem.
        extended_beams = self.beams.copy()
        extended_beams[:, COLUMNS.X0] -= staff.staffline_distance
        extended_beams[:, COLUMNS.X1] += staff.staffline_distance
        for glyph in staff.glyph:
          if glyph_types.is_beamed_notehead(glyph) and glyph.HasField('stem'):
            # Normalize the stem endpoints into a bounding box.
            xs = [glyph.stem.start.x, glyph.stem.end.x]
            ys = [glyph.stem.start.y, glyph.stem.end.y]
            stem_bounding_box = np.asarray([[min(*xs), min(*ys)],
                                            [max(*xs), max(*ys)]])
            overlapping_beams = _get_overlapping_beams(stem_bounding_box,
                                                       extended_beams)
            glyph.beam.extend(
                musicscore_pb2.LineSegment(
                    start=musicscore_pb2.Point(
                        x=beam[COLUMNS.X0], y=beam[COLUMNS.Y0]),
                    end=musicscore_pb2.Point(
                        x=beam[COLUMNS.X1], y=beam[COLUMNS.Y1]))
                for beam in overlapping_beams)
    return page
def _get_overlapping_beams(stem, beams):
  """Filters beams that overlap with the stem.

  Args:
    stem: NumPy array `((x0, y0), (x1, y1))` representing the stem line.
    beams: NumPy array of shape `(num_beams, 2, 2)`. The line segment for every
      candidate beam.

  Returns:
    Filtered beams of shape `(num_filtered_beams, 2, 2)`. All of the beams
    which intersect with the given stem.
  """
  # A beam touches the stem iff both its x interval and its y interval overlap
  # the stem's. The single stem is broadcast against every beam.
  stem_x_interval = stem[None, :, 0]
  stem_y_interval = stem[None, :, 1]
  beam_x_intervals = beams[:, [COLUMNS.X0, COLUMNS.X1]]
  beam_y_intervals = beams[:, [COLUMNS.Y0, COLUMNS.Y1]]
  overlaps = np.logical_and(
      _do_intervals_overlap(stem_x_interval, beam_x_intervals),
      _do_intervals_overlap(stem_y_interval, beam_y_intervals))
  return beams[overlaps]
def _maybe_duplicate_beams(beams, staffline_distance):
  """Determines whether each candidate beam actually contains multiple beams.

  Beams are normally separated by a narrow space, but sometimes they can blur
  together. Example: https://imgur.com/2ompQAz.png

  The number of black pixels in a single beam is proportional to its width. If
  the total area of the component is a multiple of the expected area, repeat the
  beam to count as multiple beams.

  Args:
    beams: The connected component array with shape (N, 5). The values in the
      columns are determined in `components.ConnectedComponentsColumns`.
    staffline_distance: The scalar staffline distance (median from all staves).

  Returns:
    The beams array, possibly with some beam candidates repeated along the 0th
    axis.
  """

  def _estimate_num_beams(beam):
    """Estimates how many beams are merged into one component, at least 1."""
    width = beam[COLUMNS.X1] - beam[COLUMNS.X0]
    # Beams appear to typically be slightly shorter than the staffline distance.
    estimated_area_per_beam = width * staffline_distance * 0.75
    # Cast to a Python int: np.round returns a float, and np.repeat requires
    # integral repeat counts.
    return max(1, int(np.round(beam[COLUMNS.SIZE] / estimated_area_per_beam)))

  estimated_num_beams = [_estimate_num_beams(beam) for beam in beams]
  return np.repeat(beams, estimated_num_beams, axis=0)
def _do_intervals_overlap(intervals_a, intervals_b):
"""Whether the intervals overlap, pairwise.
intervals_a and intervals_b should both have the same shape
`(num_intervals, 2)`. For each interval from each argument, returns a boolean
of whether the numeric intervals overlap.
Args:
intervals_a: Numeric NumPy array of shape `(num_intervals, 2)`.
intervals_b: Numeric NumPy array of shape `(num_intervals, 2)`.
Returns:
Boolean NumPy array of length `num_intervals`.
"""
def contained(points, intervals):
return np.logical_and(
np.less_equal(intervals[:, 0], points),
np.less_equal(points, intervals[:, 1]))
return np.logical_or(
np.logical_or(
contained(intervals_a[:, 0], intervals_b),
contained(intervals_a[:, 1], intervals_b)),
np.logical_or(
contained(intervals_b[:, 0], intervals_a),
contained(intervals_b[:, 1], intervals_a)))
================================================
FILE: moonlight/structure/beams.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detects note beams.
Beams are long, very thick, horizontal or diagonal lines that may intersect with
the staves. To detect them, we use staff removal followed by extra binary
erosion, in case the staves are not completely removed and still have extra
black pixels around a beam. We then find all of the connected components,
because each beam should now be detached from the stem, staff, and (typically)
other beams. We filter beams by minimum width. Further processing and assignment
of stems to beams is done in `beam_processor.py`.
Each beam halves the duration of each note it is attached to by a stem.
"""
# TODO(ringw): Make Hough line segments more robust, and then use them here
# instead of connected components.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.structure import components
from moonlight.vision import morphology
COLUMNS = components.ConnectedComponentsColumns
class Beams(object):
  """Note beam detector.

  Finds connected components on the eroded, staves-removed image and keeps
  only the components wide enough to plausibly be beams.
  """

  def __init__(self, staff_remover, threshold=127):
    """Builds the beam-detection subgraph.

    Args:
      staff_remover: A staff remover; its `remove_staves` image is scanned.
      threshold: Pixel values below this count as black.
    """
    detector = staff_remover.staff_detector
    # Erode by the staffline thickness in case stave removal left stray
    # black pixels around a beam.
    eroded = morphology.binary_erosion(
        tf.less(staff_remover.remove_staves, threshold),
        detector.staffline_thickness)
    candidates = components.get_component_bounds(eroded)
    # Guard the mean with tf.cond so an image with zero detected staves
    # yields a zero staffline distance instead of an invalid reduction.
    staffline_distance = tf.cond(
        tf.greater(tf.shape(detector.staves)[0], 0),
        lambda: tf.reduce_mean(detector.staffline_distance),
        lambda: tf.constant(0, tf.int32))
    widths = candidates[:, COLUMNS.X1] - candidates[:, COLUMNS.X0]
    # Keep only components at least 2 staffline distances wide.
    is_wide_enough = tf.greater_equal(widths, 2 * staffline_distance)
    is_wide_enough.set_shape([None])
    self.beams = tf.boolean_mask(candidates, is_wide_enough)
    self.data = [self.beams]
class ComputedBeams(object):
  """Holder for the computed beams NumPy array."""

  def __init__(self, beams):
    """Stores the computed beam bounds.

    Args:
      beams: Array-like of beam component bounds; stored as an int32 array.
    """
    self.beams = np.asarray(beams, np.int32)
    # Expose `data` for consistency with the other Computed* holders in this
    # package (ComputedComponents, ComputedVerticals), which list the arrays
    # they hold.
    self.data = [self.beams]
================================================
FILE: moonlight/structure/components.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Connected component analysis.
Connected components will be used to detect solid-ish elements or blobs on the
score (e.g. beams, dots, and whole/half rests).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
import tensorflow as tf
from tensorflow.contrib import image as contrib_image
class ConnectedComponentsColumns(enum.IntEnum):
  """The field names of the connected components 2D array columns."""
  # (X0, Y0) is the top-left and (X1, Y1) the bottom-right corner of the
  # component bounding box (see `get_component_bounds`). SIZE is the count of
  # True pixels in the component.
  X0 = 0
  Y0 = 1
  X1 = 2
  Y1 = 3
  SIZE = 4
def from_staff_remover(staff_remover, threshold=127):
  """Builds ConnectedComponents from a staff remover's output image.

  Args:
    staff_remover: Object exposing a `remove_staves` image tensor.
    threshold: Pixel values below this count as black.

  Returns:
    A `ConnectedComponents` holding the component bounds of the thresholded,
    staves-removed image.
  """
  with tf.name_scope('from_staff_remover'):
    binarized = tf.less(staff_remover.remove_staves, threshold)
    bounds = get_component_bounds(binarized)
    return ConnectedComponents(bounds)
class ConnectedComponents(object):
  """Holds the connected components on the staves removed image."""

  def __init__(self, components):
    """Wraps the component bounds tensor produced by `get_component_bounds`."""
    with tf.name_scope('ConnectedComponents'):
      self.components = components
      # `data` lists the tensors that must be evaluated to compute this
      # structural element.
      self.data = [components]
class ComputedComponents(object):
  """Holds the computed NumPy array of connected components."""

  def __init__(self, components):
    """Stores the computed component array as-is (no copy)."""
    self.components = components
    self.data = [self.components]
def get_component_bounds(image):
  """Returns the bounding box of each connected component in `image`.

  Connected components are segments of adjacent True pixels in the image.

  Args:
    image: A 2D boolean image tensor.

  Returns:
    A tensor of shape (num_components, 5), where each row represents a
    connected component of the image as `(x0, y0, x1, y1, size)`. `size` is
    the count of True pixels in the component, and the coordinates are the
    top left and bottom right corners of the bounding box.
  """
  with tf.name_scope('get_component_bounds'):
    labels = contrib_image.connected_components(image)
    # Label 0 is the background; the leading entry of each per-segment
    # reduction is sliced off below.
    num_segments = tf.reduce_max(labels) + 1
    image_shape = tf.shape(image)
    xs, ys = tf.meshgrid(tf.range(image_shape[1]), tf.range(image_shape[0]))

    def per_component_min(values):
      return _unsorted_segment_min(values, labels, num_segments)[1:]

    def per_component_max(values):
      return tf.unsorted_segment_max(values, labels, num_segments)[1:]

    # Pixel count per component, excluding the background.
    sizes = tf.bincount(labels)[1:]
    return tf.stack(
        [
            per_component_min(xs), per_component_min(ys),
            per_component_max(xs), per_component_max(ys), sizes
        ],
        axis=1)
def _unsorted_segment_min(data, segment_ids, num_segments):
  """Per-segment minimum, expressed as the negated max of the negation."""
  negated_max = tf.unsorted_segment_max(-data, segment_ids, num_segments)
  return -negated_max
================================================
FILE: moonlight/structure/components_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for connected component analysis."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.structure import components
class ComponentsTest(tf.test.TestCase):

  def testComponents(self):
    """Checks component bounding boxes on a small hand-written image."""
    image = [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
             [1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1],
             [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
             [1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
             [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]  # pyformat: disable
    bounds_t = components.get_component_bounds(tf.cast(image, tf.bool))
    with self.test_session():
      bounds = bounds_t.eval()
      # Each row is (x0, y0, x1, y1, size).
      self.assertAllEqual(
          bounds,
          [[5, 0, 5, 0, 1], [0, 1, 4, 5, 16], [8, 1, 10, 3, 5],
           [12, 1, 12, 2, 2], [2, 3, 2, 3, 1], [6, 4, 6, 4, 1]])
# Run the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/structure/section_barlines.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detects section barlines, which are much thicker than normal barlines.
Section barlines appear as connected components which span the height of the
system, and are not too thick. They may have 2 repeat dots on one or both sides
of each staff (at y positions -1 and 1), which affect the barline type.
"""
# TODO(ringw): Get repeat dots from the components and adjust the barline
# type accordingly. Currently, assume all thick barlines are END_BAR.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from moonlight.protobuf import musicscore_pb2
from moonlight.structure import barlines
from moonlight.structure import components as components_module
Bar = musicscore_pb2.StaffSystem.Bar # pylint: disable=invalid-name
COLUMNS = components_module.ConnectedComponentsColumns
class SectionBarlines(object):
  """Reads the connected components, and adds thick barlines to the page."""

  def __init__(self, structure):
    """Caches the component bounds and staff detector from `structure`."""
    self.components = structure.connected_components.components
    self.staff_detector = structure.staff_detector

  def apply(self, page):
    """Detects thick section barlines from the connected components.

    These should be tall components that start and end near the start and end
    of two (possibly different) staves. We use the standard barlines logic to
    assign components to the nearest start and end staff. We filter for
    candidate barlines, whose start and end are sufficiently close to the
    expected values. We then filter again by whether the component width is
    within the expected values for section barlines.

    For each staff system, we take the section barlines that match exactly that
    system's staves. Any standard barlines that are too close to a new section
    barline are removed, and we merge the existing standard barlines with the
    new section barlines.

    Args:
      page: A Page message.

    Returns:
      The same Page message, with new section barlines added.
    """
    # Horizontal center of each component, used as the barline x coordinate.
    component_center_x = np.mean(
        self.components[:, [COLUMNS.X0, COLUMNS.X1]], axis=1).astype(int)
    # Take section barline candidates, whose start and end y values are close
    # enough to the staff start and end ys.
    component_is_candidate, candidate_start_staff, candidate_end_staff = (
        barlines.assign_barlines_to_staves(
            barline_x=component_center_x,
            barline_y0=self.components[:, COLUMNS.Y0],
            barline_y1=self.components[:, COLUMNS.Y1],
            staff_detector=self.staff_detector))
    candidates = self.components[component_is_candidate]
    candidate_center_x = component_center_x[component_is_candidate]
    del component_center_x
    # Filter again by the expected section barline width.
    component_width = candidates[:, COLUMNS.X1] - candidates[:, COLUMNS.X0]
    component_width_ok = np.logical_and(
        self._section_min_width() <= component_width,
        component_width <= self._section_max_width(candidate_start_staff))
    candidates = candidates[component_width_ok]
    candidate_center_x = candidate_center_x[component_width_ok]
    candidate_start_staff = candidate_start_staff[component_width_ok]
    candidate_end_staff = candidate_end_staff[component_width_ok]
    # For each existing staff system, consider only the candidates that match
    # exactly the system's start and end staves.
    start_staff = 0
    for system in page.system:
      staffline_distance = np.median(
          [staff.staffline_distance for staff in system.staff]).astype(int)
      # A candidate covers this system iff it starts on the system's first
      # staff and ends on its last staff.
      candidate_covers_staff_system = np.logical_and(
          candidate_start_staff == start_staff,
          candidate_end_staff + 1 == start_staff + len(system.staff))
      # Calculate the x coordinates of all section barlines to keep.
      section_bar_x = candidate_center_x[candidate_covers_staff_system]
      # Extract the existing bar x coordinates and types for merging.
      existing_bar_type = {bar.x: bar.type for bar in system.bar}
      existing_bars = np.asarray([bar.x for bar in system.bar])
      # Merge the existing barlines and section barlines.
      if existing_bars.size and section_bar_x.size:
        # Filter the existing bars by whether they are far enough from a new
        # section barline. Section barlines override the existing standard
        # barlines.
        existing_bars_ok = np.greater(
            np.min(
                np.abs(existing_bars[:, None] - section_bar_x[None, :]),
                axis=1), staffline_distance * 4)
        existing_bars = existing_bars[existing_bars_ok]
      # Merge the existing barlines which we kept, and the new section barlines
      # (which are assumed to be type END_BAR), in sorted order.
      bars = sorted(
          [Bar(x=x, type=existing_bar_type[x]) for x in existing_bars] +
          [Bar(x=x, type=Bar.END_BAR) for x in section_bar_x],
          key=lambda bar: bar.x)
      # Update the staff system.
      system.ClearField('bar')
      system.bar.extend(bars)

      start_staff += len(system.staff)
    return page

  def _section_min_width(self):
    # Section barlines must be at least 3 staffline thicknesses wide.
    return self.staff_detector.staffline_thickness * 3

  def _section_max_width(self, staff_index):
    # ...and at most 2 staffline distances wide, using the staffline
    # distance of the staff where each candidate starts.
    return self.staff_detector.staffline_distance[staff_index] * 2
class MergeStandardAndBeginRepeatBars(object):
  """Detects a begin repeat at the beginning of the staff system.

  Typically, a begin repeat bar on a new line will be preceded by a standard
  barline, clef, and key signature. We can override a standard bar with a
  section bar if they are close together, but this distance is typically closer
  than the two bars are in this case.

  We want the two bars to be replaced by a single begin repeat bar where we
  actually found the first bar, because we want the clef, key signature, and
  notes to be a single measure.

  Because we don't yet detect repeat dots, and all non-STANDARD barlines are
  detected as END_BAR, we accept any non-STANDARD barlines for the second bar.
  """

  def __init__(self, structure):
    self.staff_detector = structure.staff_detector

  def apply(self, page):
    """Folds a leading standard bar into the following non-standard bar.

    Args:
      page: A Page message.

    Returns:
      The same Page message, with each qualifying system's first two bars
      merged in place.
    """
    for system in page.system:
      # Guard: need a standard first bar immediately followed by a
      # non-standard second bar.
      if len(system.bar) <= 1:
        continue
      first = system.bar[0]
      second = system.bar[1]
      if first.type != Bar.STANDARD_BAR or second.type == Bar.STANDARD_BAR:
        continue
      median_staffline_distance = np.median(
          [staff.staffline_distance for staff in system.staff])
      # Merge only if the two bars are close enough together.
      if second.x - first.x < median_staffline_distance * 12:
        first.type = second.type
        del system.bar[1]
    return page
================================================
FILE: moonlight/structure/stems.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stem detection.
A stem detector takes stem candidates from `ColumnBasedVerticals`, which are
vertical lines with a height close to the expected height of a stem.
The distance is computed from each notehead to each stem, and the notehead is
assigned to the closest stem if the distance is below a threshold. The distance
is based on the coordinate where the notehead would ideally lie if it belongs
to the stem. First, the glyph y is clamped to the range of the stem, because the
center of the notehead should not be above or below the stem. Next, if the glyph
is left of the stem, the ideal x is a constant left of the stem, and similar if
it is right of the stem. This is because the left or right side of the notehead
should touch the stem, so they are ideally a fixed distance away.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from moonlight.glyphs import glyph_types
from moonlight.protobuf import musicscore_pb2
# The minimum height of a stem, as a multiple of the staffline distance.
_MIN_STEM_HEIGHT_STAFFLINE_DISTANCE = 2.5
# The expected horizontal distance from the notehead to the stem.
_STEM_NOTEHEAD_HORIZONTAL_STAFFLINE_DISTANCE = 0.5
# The maximum Euclidean distance from a notehead to its ideal position for a
# given stem (see module docstring).
_STEM_NOTEHEAD_DISTANCE_STAFFLINE_DISTANCE = 0.5
class Stems(object):
  """Stem detector."""

  def __init__(self, structure):
    """Constructs the stem detector.

    Args:
      structure: A computed structure.

    Raises:
      ValueError: If structure.is_computed() is false.
    """
    if not structure.is_computed():
      raise ValueError("Run Structure.compute() before passing it here")
    self.staff_detector = structure.staff_detector
    # Page-wide mean staffline distance, used to set the minimum stem height.
    staffline_distance = np.mean(self.staff_detector.staffline_distance)
    verticals = structure.verticals
    self.stem_candidates = _get_stem_candidates(staffline_distance, verticals)

  def apply(self, page):
    """Detects stems on the page.

    Using `self.stem_candidates`, finds verticals that align with a notehead
    glyph, and adds the stems.

    Args:
      page: The Page message.

    Returns:
      The same page, updated with stems.
    """
    for system in page.system:
      for staff, staff_ys in zip(system.staff,
                                 self.staff_detector.staves_interpolated_y):
        # Both distance thresholds scale with this staff's staffline distance.
        allowed_distance = np.multiply(
            _STEM_NOTEHEAD_DISTANCE_STAFFLINE_DISTANCE,
            staff.staffline_distance)
        expected_horizontal_distance = np.multiply(
            _STEM_NOTEHEAD_HORIZONTAL_STAFFLINE_DISTANCE,
            staff.staffline_distance)
        for glyph in staff.glyph:
          if glyph_types.is_stemmed_notehead(glyph):
            # Absolute y pixel coordinate of the glyph; y_position is measured
            # in half staffline distances from the staff center line.
            glyph_y = (
                staff_ys[glyph.x] -
                glyph.y_position * staff.staffline_distance / 2.0)
            # Compute the ideal coordinates for the glyph to be assigned to
            # each stem.
            # Clip the glyph_y to the stem start and end y to get the ideal y.
            ideal_y = np.clip(glyph_y, self.stem_candidates[:, 0, 1],
                              self.stem_candidates[:, 1, 1])
            # If the glyph is left of the stem, subtract the expected distance
            # from the stem x; otherwise, add it.
            ideal_x = self.stem_candidates[:, 0, 0] + np.where(
                glyph.x < self.stem_candidates[:, 0, 0],
                -expected_horizontal_distance, expected_horizontal_distance)
            stem_distance = np.linalg.norm(
                np.c_[ideal_x - glyph.x, ideal_y - glyph_y], axis=1)
            # Assign the glyph to the closest stem, if it is close enough.
            # NOTE(review): np.argmin raises on an empty array — this seems to
            # assume at least one stem candidate exists; confirm upstream.
            stem = np.argmin(stem_distance)
            if stem_distance[stem] <= allowed_distance:
              stem_coords = self.stem_candidates[stem]
              glyph.stem.CopyFrom(
                  musicscore_pb2.LineSegment(
                      start=musicscore_pb2.Point(
                          x=stem_coords[0, 0], y=stem_coords[0, 1]),
                      end=musicscore_pb2.Point(
                          x=stem_coords[1, 0], y=stem_coords[1, 1])))
    return page
def _get_stem_candidates(staffline_distance, verticals):
  """Filters the vertical lines down to those tall enough to be stems.

  Args:
    staffline_distance: Scalar staffline distance, in pixels.
    verticals: Holder with a `lines` array of shape (num_lines, 2, 2).

  Returns:
    The subset of `verticals.lines` whose height is at least
    `_MIN_STEM_HEIGHT_STAFFLINE_DISTANCE` staffline distances.
  """
  lines = verticals.lines
  min_height = staffline_distance * _MIN_STEM_HEIGHT_STAFFLINE_DISTANCE
  is_tall_enough = lines[:, 1, 1] - lines[:, 0, 1] >= min_height
  return lines[is_tall_enough]
================================================
FILE: moonlight/structure/stems_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for stem detection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
from moonlight import structure
from moonlight.protobuf import musicscore_pb2
from moonlight.staves import base as staves_base
from moonlight.structure import beams
from moonlight.structure import components
from moonlight.structure import stems as stems_module
from moonlight.structure import verticals
Point = musicscore_pb2.Point # pylint: disable=invalid-name
class StemsTest(absltest.TestCase):

  def testDummy(self):
    """End-to-end check of stem assignment on a synthetic structure."""
    # Create a single staff, and a single vertical which is the correct height
    # of a stem. The vertical has x = 20 and goes from y = 38 to
    # y = 38 + 12 * 4 = 86 (four staffline distances tall).
    struct = structure.Structure(
        staff_detector=staves_base.ComputedStaves(
            staves=[[[10, 50], [90, 50]]],
            staffline_distance=[12],
            staffline_thickness=2,
            staves_interpolated_y=[[50] * 100]),
        beams=beams.ComputedBeams(np.zeros((0, 2, 2))),
        verticals=verticals.ComputedVerticals(
            lines=[[[20, 38], [20, 38 + 12 * 4]]]),
        connected_components=components.ComputedComponents([]))
    stems = stems_module.Stems(struct)
    # Create a Page with Glyphs.
    input_page = musicscore_pb2.Page(system=[
        musicscore_pb2.StaffSystem(staff=[
            musicscore_pb2.Staff(
                staffline_distance=12,
                center_line=[
                    musicscore_pb2.Point(x=10, y=50),
                    musicscore_pb2.Point(x=90, y=50)
                ],
                glyph=[
                    # Cannot have a stem because it's a flat.
                    musicscore_pb2.Glyph(
                        type=musicscore_pb2.Glyph.FLAT, x=15, y_position=-1),
                    # On the right side of the stem, the correct distance away.
                    musicscore_pb2.Glyph(
                        type=musicscore_pb2.Glyph.NOTEHEAD_FILLED,
                        x=25,
                        y_position=-1),
                    # Too high for the stem.
                    musicscore_pb2.Glyph(
                        type=musicscore_pb2.Glyph.NOTEHEAD_FILLED,
                        x=25,
                        y_position=4),
                    # Too far right from the stem.
                    musicscore_pb2.Glyph(
                        type=musicscore_pb2.Glyph.NOTEHEAD_FILLED,
                        x=35,
                        y_position=-1),
                ])
        ])
    ])
    page = stems.apply(input_page)
    # Only the second glyph (index 1) should be assigned to the stem.
    self.assertFalse(page.system[0].staff[0].glyph[0].HasField("stem"))
    self.assertTrue(page.system[0].staff[0].glyph[1].HasField("stem"))
    self.assertEqual(
        page.system[0].staff[0].glyph[1].stem,
        musicscore_pb2.LineSegment(
            start=Point(x=20, y=38), end=Point(x=20, y=38 + 12 * 4)))
    self.assertFalse(page.system[0].staff[0].glyph[2].HasField("stem"))
    self.assertFalse(page.system[0].staff[0].glyph[3].HasField("stem"))
# Run the absl test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()
================================================
FILE: moonlight/structure/structure_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for structure computation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import numpy as np
import tensorflow as tf
from moonlight import image as image_module
from moonlight import structure
class StructureTest(tf.test.TestCase):

  def testCompute(self):
    """Computes the structure of a real score image and checks shapes."""
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../testdata/IMSLP00747-000.png')
    image = image_module.decode_music_score_png(tf.read_file(filename))
    struct = structure.create_structure(image)
    with self.test_session():
      # compute() returns the structure with NumPy arrays filled in.
      struct = struct.compute()
    self.assertEqual(np.int32, struct.staff_detector.staves.dtype)
    # Expected number of staves for the corpus image.
    self.assertEqual((12, 2, 2), struct.staff_detector.staves.shape)
    self.assertEqual(np.int32, struct.verticals.lines.dtype)
    # Verticals are (num_lines, 2, 2): ((start_x, start_y), (end_x, end_y)).
    self.assertEqual(3, struct.verticals.lines.ndim)
    self.assertEqual((2, 2), struct.verticals.lines.shape[1:])
# Run the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/structure/verticals.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detects vertical lines using the runs in each column of the image.
After the TensorFlow graph is run, vertical lines are classified as stems or
barlines.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.util import functional_ops
from moonlight.util import memoize
from moonlight.util import segments
from moonlight.vision import images
from moonlight.vision import morphology
# Join gaps in vertical lines using a small gap (relative to the difference
# between stafflines).
_DEFAULT_MAX_GAP_STAFFLINE_DISTANCE = frozenset((0.1, 0.2, 0.5))
# Beams and barlines should be at least the height of a staff (4 stafflines).
# Use a minimum of 2.5 * staffline distance.
_DEFAULT_MIN_LENGTH_STAFFLINE_DISTANCE = 2.5
class ColumnBasedVerticals(object):
  """Vertical line segment detector.

  Does not resolve duplicates across multiple columns, or the same line that is
  detected multiple times with different `max_gap` settings. The user should
  expect to find multiple detected lines that correspond to the same physical
  line, and choose the line that seems the most likely for a given purpose.

  Attributes:
    staff_detector: An instance of `staves.base.BaseStaffDetector`.
    image: The uint8 image tensor.
    threshold: The image threshold. int.
    thresholded_image: Whether each pixel of the image is black.
    max_gap: Multiple values for the maximum gap allowed in a line segment, in
      pixels. Tensor (1D) of ints.
    min_length: The minimum length of a line segment, in pixels. int.
  """

  def __init__(
      self,
      staff_detector,
      threshold=127,
      max_gap_staffline_distance=_DEFAULT_MAX_GAP_STAFFLINE_DISTANCE,
      min_length_staffline_distance=_DEFAULT_MIN_LENGTH_STAFFLINE_DISTANCE):
    self.staff_detector = staff_detector
    self.image = staff_detector.image
    self.threshold = threshold
    thresholded_image = tf.less(self.image, threshold)
    # Keep only pixels with a white neighbor on one side (see
    # `_horizontal_filter`), then dilate by 1 pixel.
    self.filtered_image = morphology.binary_dilation(
        _horizontal_filter(thresholded_image,
                           staff_detector.staffline_thickness), 1)
    staffline_distance = tf.reduce_mean(staff_detector.staffline_distance)
    # Deterministically convert max_gap_staffline_distance to a list.
    # We use a frozenset so there is no risk of mutating the default argument.
    self.max_gap = tf.to_int32(
        tf.round(
            tf.to_float(staffline_distance) *
            sorted(max_gap_staffline_distance)))
    self.min_length = tf.to_int32(
        tf.round(
            tf.to_float(staffline_distance) * min_length_staffline_distance))

  @property
  @memoize.MemoizedFunction
  def lines(self):
    """The vertical lines.

    Returns:
      int32 tensor of shape (num_lines, 2, 2), storing lines as
      ((start_x, start_y), (end_x, end_y)).
    """
    columns = tf.range(tf.shape(self.image)[1])

    def map_max_gap(max_gap):
      """Process all columns with the given value for max_gap."""
      return functional_ops.flat_map_fn(
          lambda column: self._verticals_in_column(max_gap, column), columns)

    # Concatenates the lines detected for every max_gap value; duplicates
    # across max_gap values are intentionally kept (see class docstring).
    return functional_ops.flat_map_fn(map_max_gap, self.max_gap)

  def _verticals_in_column(self, max_gap, column):
    """Gets the verticals from a single column.

    Args:
      max_gap: The scalar max_gap value to use. int tensor.
      column: The scalar column index. int tensor.

    Returns:
      int32 tensor of shape (num_lines_in_column, 2, 2). All start_x and end_x
      values are equal to column.
    """
    image_column = self.filtered_image[:, column]
    run_starts, run_lengths = segments.true_segments_1d(
        image_column,
        mode=segments.SegmentsMode.STARTS,
        max_gap=max_gap,
        min_length=self.min_length)
    num_runs = tf.shape(run_starts)[0]
    # x is the same for all runs in the column.
    x = tf.fill([num_runs], column)
    y0 = run_starts
    # End y is inclusive.
    y1 = run_starts + run_lengths - 1
    return tf.stack([
        tf.stack([x, y0], axis=1),
        tf.stack([x, y1], axis=1),
    ],
                    axis=1)

  @property
  def data(self):
    """Returns the list of verticals tensors to be computed.

    Returns:
      A list of Tensors.
    """
    return [self.lines]
def _horizontal_filter(image, staffline_thickness):
  """The vertical lines horizontal filter.

  A black pixel in a vertical line must have a white pixel
  `2 * staffline_thickness` pixels away, on the left and/or right.

  Args:
    image: 2D thresholded boolean image with black pixels as True.
    staffline_thickness: The estimated staffline thickness, in pixels.

  Returns:
    The filtered image.
  """
  # images.translate() requires a float image. Unlike the convention (255 or
  # 1.0 for white), the image is already thresholded here, so 1.0 is black and
  # 0.0 is white.
  as_float = tf.cast(image, tf.float32)
  offset = staffline_thickness * 2
  white_neg_shift = tf.equal(images.translate(as_float, -offset, 0), 0)
  white_pos_shift = tf.equal(images.translate(as_float, offset, 0), 0)
  return tf.logical_and(image, tf.logical_or(white_neg_shift, white_pos_shift))
class ComputedVerticals(object):
  """Computed vertical lines holder.

  The result of `ColumnBasedVerticals.compute()`. Holds a NumPy array with the
  vertical lines.
  """

  def __init__(self, lines):
    """Converts `lines` to a NumPy array and stores it."""
    lines_arr = np.array(lines)
    self.lines = lines_arr

  @property
  def data(self):
    """The list of arrays held by this structure element."""
    return [self.lines]
================================================
FILE: moonlight/structure/verticals_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for vertical line detection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.staves import testing
from moonlight.structure import verticals
class ColumnBasedVerticalsTest(tf.test.TestCase):

  def testVerticalLines_singleColumn(self):
    """Detects runs in individual columns of a tiny synthetic image."""
    # True pixels are black. Columns 0, 1, and 3 each hold one vertical run.
    image = np.zeros((20, 4), bool)
    image[5:10, 0] = True
    image[11:15, 1] = True
    image[:5, 3] = True
    staff_detector = testing.FakeStaves(
        tf.constant(np.where(image, 0, 255), tf.uint8),
        staves_t=None,
        staffline_distance_t=[1],
        staffline_thickness_t=0.5)
    verticals_detector = verticals.ColumnBasedVerticals(
        staff_detector,
        max_gap_staffline_distance=[1],
        min_length_staffline_distance=3)
    lines_t = verticals_detector.lines
    with self.test_session():
      lines = [line.tolist() for line in lines_t.eval()]
      # Start is dilated by 1 pixel since the start is actually in this column.
      self.assertIn([[0, 4], [0, 14]], lines)
      # Only the end is contained in this column, so it is dilated but the start
      # is not.
      self.assertIn([[1, 5], [1, 15]], lines)
      self.assertNotIn([[2, 4], [2, 14]], lines)
      self.assertNotIn([[2, 5], [2, 15]], lines)
      self.assertIn([[3, 0], [3, 5]], lines)
      # Out of bounds.
      self.assertNotIn([[4, 0], [4, 5]], lines)
# Run the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/testdata/BUILD
================================================
# Description:
# Music scores used for inference and testing the OMR pipeline.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
filegroup(
name = "images",
srcs = ["IMSLP00747-000.png"],
)
filegroup(
name = "musicxml",
srcs = glob(["*.xml"]),
)
================================================
FILE: moonlight/testdata/IMSLP00747-000.LICENSE.md
================================================
The image was obtained from IMSLP:
https://imslp.org/wiki/Special:ReverseLookup/747
It was published in 1853 and is in the public domain worldwide.
================================================
FILE: moonlight/testdata/IMSLP00747.golden.LICENSE.md
================================================
The score was transcribed by the Moonlight authors. The MusicXML transcription
is released under the Apache License, version 2.0, as with all code.
================================================
FILE: moonlight/testdata/IMSLP00747.golden.xml
================================================
Invention No. 1
J. S. Bach
MuseScore 2.1.0
2017-09-18
Dan Ringwalt (ringwalt@google.com)
7.05556
40
1683.36
1190.88
56.6929
56.6929
56.6929
113.386
56.6929
56.6929
56.6929
113.386
Invention No. 1
J. S. Bach
Piano
Pno.
Piano
1
1
78.7402
0
21.00
-0.00
170.00
65.00
8
0
4
4
2
G
2
F
4
quarter
80
1
2
1
16th
1
C
4
2
1
16th
up
1
begin
begin
D
4
2
1
16th
up
1
continue
continue
E
4
2
1
16th
up
1
end
end
F
4
2
1
16th
up
1
begin
begin
D
4
2
1
16th
up
1
continue
continue
E
4
2
1
16th
up
1
continue
continue
C
4
2
1
16th
up
1
end
end
G
4
4
1
eighth
up
1
begin
C
5
4
1
eighth
up
1
continue
B
4
4
1
eighth
up
1
continue
C
5
4
1
eighth
up
1
end
32
16
5
half
2
2
5
16th
2
C
3
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
end
end
F
3
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
continue
continue
C
3
2
5
16th
down
2
end
end
D
5
2
1
16th
down
1
begin
begin
G
4
2
1
16th
down
1
continue
continue
A
4
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
end
end
C
5
2
1
16th
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
B
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
end
end
D
5
4
1
eighth
down
1
begin
G
5
4
1
eighth
down
1
continue
F
5
4
1
eighth
down
1
continue
G
5
4
1
eighth
down
1
end
32
G
3
4
5
eighth
up
2
begin
G
2
4
5
eighth
up
2
end
8
5
quarter
2
2
5
16th
2
G
3
2
5
16th
down
2
begin
begin
A
3
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
end
end
C
4
2
5
16th
down
2
begin
begin
A
3
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
continue
continue
G
3
2
5
16th
down
2
end
end
21.00
-0.00
137.07
65.00
E
5
2
1
16th
down
1
begin
begin
A
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
end
end
E
5
2
1
16th
down
1
begin
begin
G
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
A
5
2
1
16th
down
1
end
end
G
5
2
1
16th
down
1
begin
begin
F
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
C
5
2
1
16th
down
1
begin
begin
E
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
end
end
32
C
4
4
5
eighth
down
2
begin
B
3
4
5
eighth
down
2
continue
C
4
4
5
eighth
down
2
continue
D
4
4
5
eighth
down
2
end
E
4
4
5
eighth
down
2
begin
G
3
4
5
eighth
down
2
continue
A
3
4
5
eighth
down
2
continue
B
3
4
5
eighth
down
2
end
E
5
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
end
end
A
4
2
1
16th
down
1
begin
begin
C
5
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
C
5
2
1
16th
up
1
begin
begin
B
4
2
1
16th
up
1
continue
continue
A
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
end
end
F
1
4
2
1
16th
sharp
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
continue
continue
B
4
2
1
16th
up
1
end
end
32
C
4
4
5
eighth
down
2
begin
E
3
4
5
eighth
down
2
continue
F
1
3
4
5
eighth
sharp
down
2
continue
G
3
4
5
eighth
down
2
end
A
3
4
5
eighth
down
2
begin
B
3
4
5
eighth
down
2
end
C
4
8
5
quarter
down
2
21.00
-0.00
137.07
65.00
A
4
4
1
eighth
up
1
begin
D
4
4
1
eighth
up
1
end
C
5
6
1
eighth
down
1
begin
D
5
2
1
16th
down
1
end
backward hook
B
4
2
1
16th
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
continue
continue
F
1
4
2
1
16th
sharp
up
1
end
end
E
4
2
1
16th
up
1
begin
begin
G
4
2
1
16th
up
1
continue
continue
F
1
4
2
1
16th
up
1
continue
continue
A
4
2
1
16th
up
1
end
end
32
C
4
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
continue
continue
F
1
3
2
5
16th
sharp
down
2
end
end
G
3
2
5
16th
down
2
begin
begin
E
3
2
5
16th
down
2
continue
continue
F
1
3
2
5
16th
down
2
continue
continue
D
3
2
5
16th
down
2
end
end
G
3
4
5
eighth
down
2
begin
B
2
4
5
eighth
down
2
continue
C
3
4
5
eighth
down
2
continue
D
3
4
5
eighth
down
2
end
G
4
2
1
16th
up
1
begin
begin
B
4
2
1
16th
up
1
continue
continue
A
4
2
1
16th
up
1
continue
continue
C
5
2
1
16th
up
1
end
end
B
4
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
D
5
2
1
16th
down
1
begin
begin
B
4
1
1
32nd
down
1
continue
continue
begin
C
5
1
1
32nd
down
1
continue
continue
end
D
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
end
end
B
4
4
1
eighth
up
1
begin
A
4
2
1
16th
up
1
continue
begin
G
4
2
1
16th
up
1
end
end
32
E
3
4
5
eighth
down
2
begin
F
1
3
4
5
eighth
sharp
down
2
continue
G
3
4
5
eighth
down
2
continue
E
3
4
5
eighth
down
2
end
B
2
6
5
eighth
up
2
begin
C
3
2
5
16th
up
2
end
backward hook
D
3
4
5
eighth
up
2
begin
D
2
4
5
eighth
up
2
end
21.00
-0.00
137.07
65.00
G
4
4
1
eighth
up
1
4
1
eighth
1
8
1
quarter
1
2
1
16th
1
G
4
2
1
16th
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
B
4
2
1
16th
up
1
end
end
C
5
2
1
16th
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
B
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
end
end
32
2
5
16th
2
G
2
2
5
16th
up
2
begin
begin
A
2
2
5
16th
up
2
continue
continue
B
2
2
5
16th
up
2
end
end
C
3
2
5
16th
up
2
begin
begin
A
2
2
5
16th
up
2
continue
continue
B
2
2
5
16th
up
2
continue
continue
G
2
2
5
16th
up
2
end
end
D
3
4
5
eighth
down
2
begin
G
3
4
5
eighth
down
2
continue
F
1
3
4
5
eighth
sharp
down
2
continue
G
3
4
5
eighth
down
2
end
F
1
4
4
1
eighth
sharp
up
1
4
1
eighth
1
8
1
quarter
1
2
1
16th
1
A
4
2
1
16th
down
1
begin
begin
B
4
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
end
end
D
5
2
1
16th
down
1
begin
begin
B
4
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
continue
continue
A
4
2
1
16th
down
1
end
end
32
A
3
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
continue
continue
F
1
3
2
5
16th
sharp
down
2
end
end
G
3
2
5
16th
down
2
begin
begin
E
3
2
5
16th
down
2
continue
continue
F
1
3
2
5
16th
down
2
continue
continue
D
3
2
5
16th
down
2
end
end
A
3
4
5
eighth
down
2
begin
D
4
4
5
eighth
down
2
continue
C
4
4
5
eighth
down
2
continue
D
4
4
5
eighth
down
2
end
21.00
-0.00
137.07
65.00
B
4
4
1
eighth
down
1
4
1
eighth
1
8
1
quarter
1
2
1
16th
1
D
5
2
1
16th
down
1
begin
begin
C
5
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
end
end
A
4
2
1
16th
down
1
begin
begin
C
5
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
32
G
3
2
5
16th
up
2
begin
begin
G
2
G
4
2
5
16th
up
2
continue
continue
F
4
2
5
16th
up
2
continue
continue
E
4
2
5
16th
up
2
end
end
D
4
2
5
16th
up
2
begin
begin
F
4
2
5
16th
up
2
continue
continue
E
4
2
5
16th
up
2
continue
continue
G
4
2
5
16th
up
2
end
end
F
4
4
5
eighth
up
2
begin
E
4
4
5
eighth
up
2
continue
F
4
4
5
eighth
up
2
continue
D
4
4
5
eighth
up
2
end
C
5
4
1
eighth
down
1
4
1
eighth
1
8
1
quarter
1
2
1
16th
1
E
5
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
end
end
B
4
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
C
1
5
2
1
16th
sharp
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
32
E
4
2
5
16th
up
2
begin
begin
A
4
2
5
16th
up
2
continue
continue
G
4
2
5
16th
up
2
continue
continue
F
4
2
5
16th
up
2
end
end
E
4
2
5
16th
up
2
begin
begin
G
4
2
5
16th
up
2
continue
continue
F
4
2
5
16th
up
2
continue
continue
A
4
2
5
16th
up
2
end
end
G
4
4
5
eighth
up
2
begin
F
4
4
5
eighth
up
2
continue
G
4
4
5
eighth
up
2
continue
E
4
4
5
eighth
up
2
end
21.00
-0.00
70.00
65.00
D
5
4
1
eighth
down
1
begin
C
1
5
4
1
eighth
sharp
down
1
continue
D
5
4
1
eighth
down
1
continue
E
5
4
1
eighth
down
1
end
F
5
4
1
eighth
down
1
begin
A
4
4
1
eighth
down
1
continue
B
4
4
1
eighth
natural
down
1
continue
C
1
5
4
1
eighth
down
1
end
32
F
4
2
5
16th
up
2
begin
begin
B
-1
4
2
5
16th
flat
up
2
continue
continue
A
4
2
5
16th
up
2
continue
continue
G
4
2
5
16th
up
2
end
end
F
4
2
5
16th
up
2
begin
begin
A
4
2
5
16th
up
2
continue
continue
G
4
2
5
16th
up
2
continue
continue
B
-1
4
2
5
16th
up
2
end
end
A
4
2
5
16th
up
2
begin
begin
G
4
2
5
16th
up
2
continue
continue
F
4
2
5
16th
up
2
continue
continue
E
4
2
5
16th
up
2
end
end
D
4
2
5
16th
up
2
begin
begin
F
4
2
5
16th
up
2
continue
continue
E
4
2
5
16th
up
2
continue
continue
G
4
2
5
16th
up
2
end
end
D
5
4
1
eighth
up
1
begin
F
1
4
4
1
eighth
sharp
up
1
continue
G
1
4
4
1
eighth
sharp
up
1
continue
A
4
4
1
eighth
up
1
end
B
4
4
1
eighth
down
1
begin
C
5
4
1
eighth
down
1
end
D
5
8
1
quarter
down
1
32
F
4
2
5
16th
up
2
begin
begin
E
4
2
5
16th
up
2
continue
continue
D
4
2
5
16th
up
2
continue
continue
C
4
2
5
16th
up
2
end
end
B
3
2
5
16th
up
2
begin
begin
D
4
2
5
16th
up
2
continue
continue
C
4
2
5
16th
up
2
continue
continue
E
4
2
5
16th
up
2
end
end
D
4
2
5
16th
up
2
begin
begin
C
4
2
5
16th
up
2
continue
continue
B
3
2
5
16th
up
2
continue
continue
A
3
2
5
16th
up
2
end
end
G
1
3
2
5
16th
sharp
up
2
begin
begin
B
3
2
5
16th
up
2
continue
continue
A
3
2
5
16th
up
2
continue
continue
C
4
2
5
16th
up
2
end
end
21.00
-0.00
100.66
65.00
F
4
D
5
2
1
16th
up
1
begin
begin
E
4
2
1
16th
up
1
continue
continue
F
1
4
2
1
16th
sharp
up
1
continue
continue
G
1
4
2
1
16th
sharp
up
1
end
end
A
4
2
1
16th
up
1
begin
begin
F
1
4
2
1
16th
up
1
continue
continue
G
1
4
2
1
16th
up
1
continue
continue
E
4
2
1
16th
up
1
end
end
E
5
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
D
5
2
1
16th
down
1
begin
begin
C
5
2
1
16th
down
1
continue
continue
B
4
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
32
B
3
4
5
eighth
down
2
begin
E
3
4
5
eighth
down
2
end
D
4
6
5
eighth
down
2
begin
E
4
2
5
16th
down
2
end
backward hook
C
4
2
5
16th
down
2
begin
begin
B
3
2
5
16th
down
2
continue
continue
A
3
2
5
16th
down
2
continue
continue
G
3
2
5
16th
natural
down
2
end
end
F
1
3
2
5
16th
sharp
down
2
begin
begin
A
3
2
5
16th
down
2
continue
continue
G
1
3
2
5
16th
sharp
down
2
continue
continue
B
3
2
5
16th
down
2
end
end
C
5
2
1
16th
down
1
begin
begin
A
5
2
1
16th
down
1
continue
continue
G
1
5
2
1
16th
sharp
down
1
continue
continue
B
5
2
1
16th
down
1
end
end
A
5
2
1
16th
down
1
begin
begin
E
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
G
1
4
2
1
16th
sharp
down
1
begin
begin
F
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
C
5
4
1
eighth
down
1
begin
B
4
2
1
16th
down
1
continue
begin
A
4
2
1
16th
down
1
end
end
32
A
3
2
5
16th
down
2
begin
begin
C
4
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
continue
continue
D
4
2
5
16th
down
2
end
end
C
4
2
5
16th
down
2
begin
begin
E
4
2
5
16th
down
2
continue
continue
D
4
2
5
16th
down
2
continue
continue
F
4
2
5
16th
down
2
end
end
E
4
4
5
eighth
down
2
begin
A
3
4
5
eighth
down
2
continue
E
4
4
5
eighth
down
2
continue
E
3
4
5
eighth
down
2
end
21.00
-0.00
100.66
65.00
A
4
2
1
16th
down
1
begin
begin
A
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
end
end
E
5
2
1
16th
down
1
begin
begin
G
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
A
5
2
1
16th
down
1
end
end
G
5
16
1
half
down
1
32
A
3
4
5
eighth
down
2
begin
A
2
4
5
eighth
down
2
end
8
5
quarter
2
2
5
16th
2
E
4
2
5
16th
down
2
begin
begin
D
4
2
5
16th
down
2
continue
continue
C
4
2
5
16th
down
2
end
end
B
3
2
5
16th
down
2
begin
begin
D
4
2
5
16th
down
2
continue
continue
C
1
4
2
5
16th
sharp
down
2
continue
continue
E
4
2
5
16th
down
2
end
end
G
5
2
1
16th
down
1
begin
begin
E
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
end
end
A
5
2
1
16th
down
1
begin
begin
F
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
F
5
16
1
half
down
1
32
D
4
16
5
half
down
2
D
4
2
5
16th
down
2
begin
begin
A
3
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
continue
continue
C
4
2
5
16th
down
2
end
end
D
4
2
5
16th
down
2
begin
begin
B
3
2
5
16th
down
2
continue
continue
C
4
2
5
16th
down
2
continue
continue
A
3
2
5
16th
down
2
end
end
21.00
-0.00
100.66
65.00
F
5
2
1
16th
down
1
begin
begin
G
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
D
5
2
1
16th
down
1
begin
begin
F
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
end
end
F
5
16
1
half
down
1
32
B
3
16
5
half
down
2
B
3
2
5
16th
down
2
begin
begin
D
4
2
5
16th
down
2
continue
continue
C
4
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
end
end
A
3
2
5
16th
down
2
begin
begin
C
4
2
5
16th
down
2
continue
continue
B
3
2
5
16th
down
2
continue
continue
D
4
2
5
16th
down
2
end
end
F
5
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
end
end
G
5
2
1
16th
down
1
begin
begin
E
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
end
end
E
5
16
1
half
down
1
32
C
4
16
5
half
down
2
C
4
2
5
16th
down
2
begin
begin
G
3
2
5
16th
down
2
continue
continue
A
3
2
5
16th
down
2
continue
continue
B
-1
3
2
5
16th
flat
down
2
end
end
C
4
2
5
16th
down
2
begin
begin
A
3
2
5
16th
down
2
continue
continue
B
-1
3
2
5
16th
down
2
continue
continue
G
3
2
5
16th
down
2
end
end
21.00
-0.00
100.66
65.00
E
5
2
1
16th
down
1
begin
begin
C
5
2
1
16th
down
1
continue
continue
D
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
F
5
2
1
16th
down
1
begin
begin
D
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
continue
continue
C
5
2
1
16th
down
1
end
end
D
5
2
1
16th
down
1
begin
begin
E
5
2
1
16th
down
1
continue
continue
F
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
end
end
A
5
2
1
16th
down
1
begin
begin
F
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
continue
continue
E
5
2
1
16th
down
1
end
end
32
A
3
4
5
eighth
down
2
begin
B
-1
3
4
5
eighth
flat
down
2
continue
A
3
4
5
eighth
down
2
continue
G
3
4
5
eighth
down
2
end
F
3
4
5
eighth
down
2
begin
D
4
4
5
eighth
down
2
continue
C
4
4
5
eighth
down
2
continue
B
-1
3
4
5
eighth
down
2
end
F
5
2
1
16th
down
1
begin
begin
G
5
2
1
16th
down
1
continue
continue
A
5
2
1
16th
down
1
continue
continue
B
5
2
1
16th
down
1
end
end
C
6
2
1
16th
down
1
begin
begin
A
5
2
1
16th
down
1
continue
continue
B
5
2
1
16th
down
1
continue
continue
G
5
2
1
16th
down
1
end
end
C
6
4
1
eighth
down
1
begin
G
5
4
1
eighth
down
1
end
E
5
4
1
eighth
down
1
begin
D
5
2
1
16th
down
1
continue
begin
C
5
2
1
16th
down
1
end
end
32
A
3
4
5
eighth
down
2
begin
F
4
4
5
eighth
down
2
continue
E
4
4
5
eighth
down
2
continue
D
4
4
5
eighth
down
2
end
E
4
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
continue
continue
F
3
2
5
16th
down
2
end
end
G
3
2
5
16th
down
2
begin
begin
E
3
2
5
16th
down
2
continue
continue
F
3
2
5
16th
down
2
continue
continue
D
3
2
5
16th
down
2
end
end
21.00
-0.00
100.66
65.00
C
5
2
1
16th
up
1
begin
begin
B
-1
4
2
1
16th
flat
up
1
continue
continue
A
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
end
end
F
4
2
1
16th
up
1
begin
begin
A
4
2
1
16th
up
1
continue
continue
G
4
2
1
16th
up
1
continue
continue
B
-1
4
2
1
16th
up
1
end
end
A
4
2
1
16th
up
1
begin
begin
B
4
2
1
16th
natural
up
1
continue
continue
C
5
2
1
16th
up
1
continue
continue
E
4
2
1
16th
up
1
end
end
D
4
2
1
16th
up
1
begin
begin
C
5
2
1
16th
up
1
continue
continue
F
4
2
1
16th
up
1
continue
continue
B
4
2
1
16th
up
1
end
end
32
E
3
4
5
eighth
down
2
begin
C
3
4
5
eighth
down
2
continue
D
3
4
5
eighth
down
2
continue
E
3
4
5
eighth
down
2
end
F
3
2
5
16th
down
2
begin
begin
D
3
2
5
16th
down
2
continue
continue
E
3
2
5
16th
down
2
continue
continue
F
3
2
5
16th
down
2
end
end
G
3
4
5
eighth
up
2
begin
G
2
4
5
eighth
up
2
end
E
4
32
1
whole
1
G
4
32
1
whole
1
C
5
32
1
whole
1
32
C
2
32
5
whole
2
C
3
32
5
whole
2
light-heavy
================================================
FILE: moonlight/testdata/README.md
================================================
# OMR Corpus
Contains test data for unit testing the OMR pipeline. Currently contains a
single page of a music score, to be used in all unit tests, and simple generated
scores containing a few measures.
The score was obtained from [IMSLP](http://imslp.org) and is in the public
domain in the United States.
================================================
FILE: moonlight/testdata/TWO_MEASURE_SAMPLE.LICENSE.md
================================================
The score was composed by the Moonlight authors. The MusicXML transcription is
released under the Apache License, version 2.0, as with all code.
================================================
FILE: moonlight/testdata/TWO_MEASURE_SAMPLE.xml
================================================
MuseScore 1.3
2017-10-13
7.05556
40
1683.78
1190.55
56.6929
56.6929
56.6929
113.386
56.6929
56.6929
56.6929
113.386
Easy Test
Just the basics
brace
Piano
Pno.
Piano
1
1
78.7402
0
87.63
0.00
180.00
1
1
major
4
4
G
2
G
4
1
1
quarter
up
A
4
1
1
quarter
up
B
4
1
1
quarter
down
F
1
4
1
1
quarter
up
G
4
1
1
quarter
up
F
1
4
1
1
quarter
up
G
4
1
1
quarter
up
G
5
1
1
quarter
down
light-heavy
================================================
FILE: moonlight/tools/BUILD
================================================
package(default_visibility = ["//moonlight:__subpackages__"])
licenses(["notice"]) # Apache 2.0
py_binary(
name = "gen_structure_test_case",
srcs = ["gen_structure_test_case.py"],
srcs_version = "PY2AND3",
deps = [
# disable_tf2
# absl dep
"//moonlight:engine",
],
)
py_binary(
name = "export_kmeans_centroids",
srcs = ["export_kmeans_centroids.py"],
srcs_version = "PY2AND3",
deps = [
":export_kmeans_centroids_lib",
# disable_tf2
],
)
py_library(
name = "export_kmeans_centroids_lib",
srcs = ["export_kmeans_centroids.py"],
data = ["//moonlight/testdata:images"],
srcs_version = "PY2AND3",
deps = [
# absl dep
"//moonlight/glyphs:corpus",
"//moonlight/glyphs:knn_model",
# tensorflow dep
],
)
py_test(
name = "export_kmeans_centroids_test",
srcs = ["export_kmeans_centroids_test.py"],
srcs_version = "PY2AND3",
# Original centroids were deleted from Piper, this test can no longer be run at head.
# Will delete soon assuming we don't need it again.
tags = [
"manual",
"notap",
],
deps = [
":export_kmeans_centroids_lib",
# disable_tf2
"//moonlight:engine",
"//moonlight/glyphs:saved_classifier",
# tensorflow dep
],
)
================================================
FILE: moonlight/tools/export_kmeans_centroids.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool to convert existing KNN tfrecords to a saved model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import tensorflow as tf
from moonlight.glyphs import corpus
from moonlight.glyphs import knn_model
def run(tfrecords_filename, export_dir):
  """Exports a KNN saved model built from labeled patch TFRecords.

  Args:
    tfrecords_filename: Path to the TFRecords of labeled staffline patches.
    export_dir: Directory where the saved model is written.
  """
  with tf.Session():
    patch_height, patch_width = corpus.get_patch_shape(tfrecords_filename)
    patches, labels = corpus.parse_corpus(
        tfrecords_filename, patch_height, patch_width)
    knn_model.export_knn_model(patches, labels, export_dir)
def main(argv):
  """CLI entry point.

  Args:
    argv: [binary_name, input_tfrecords_path, output_saved_model_dir].
  """
  # Unpacking (rather than indexing) preserves the error on a wrong arg count.
  _, tfrecords_path, saved_model_dir = argv
  run(tfrecords_path, saved_model_dir)
if __name__ == '__main__':
app.run(main)
================================================
FILE: moonlight/tools/export_kmeans_centroids_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End to end test for exporting the KNN model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tempfile
import librosa
import tensorflow as tf
from moonlight import engine
from moonlight.glyphs import saved_classifier
from moonlight.tools import export_kmeans_centroids
class ExportKmeansCentroidsTest(tf.test.TestCase):
  """End-to-end check: export the packaged centroids, then run OMR."""

  def testEndToEnd(self):
    with tempfile.TemporaryDirectory() as workdir:
      with engine.get_included_labels_file() as labels_file:
        saved_model_dir = os.path.join(workdir, 'export')
        export_kmeans_centroids.run(labels_file.name, saved_model_dir)
        # Build an engine around the model we just exported.
        classifier_fn = (
            saved_classifier.SavedConvolutional1DClassifier
            .glyph_classifier_fn(saved_model_dir))
        omr = engine.OMREngine(glyph_classifier_fn=classifier_fn)
        image_path = os.path.join(tf.resource_loader.get_data_files_path(),
                                  '../testdata/IMSLP00747-000.png')
        notes = omr.run(image_path, output_notesequence=True)
        # TODO(ringw): Fix the extra note that is detected before the actual
        # first eighth note.
        self.assertEqual(librosa.note_to_midi('C4'), notes.notes[1].pitch)
        self.assertEqual(librosa.note_to_midi('D4'), notes.notes[2].pitch)
        self.assertEqual(librosa.note_to_midi('E4'), notes.notes[3].pitch)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/tools/gen_structure_test_case.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prints assert statements about the number of systems, staves, and barlines.
After running the tool, please verify that the output is completely correct
before copying and pasting it into omr_regression_test.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import re
from absl import app
from moonlight import engine
def main(argv):
  """Prints structure assertions for each score page given on the command line.

  Args:
    argv: Command-line arguments; argv[1:] are paths to PNG score pages.
  """
  pages = argv[1:]
  assert pages, 'Pass one or more PNG files'
  omr = engine.OMREngine()
  for page_num, filename in enumerate(pages):
    # Escape quotes and backslashes so the filename survives inside a Python
    # string literal. \1 re-inserts the matched character after the escaping
    # backslash; the previous replacement r'\\\0' was a bug (\0 is an octal
    # escape in a replacement template, so it emitted a NUL byte instead).
    escaped_filename = re.sub(r'([\'\\])', r'\\\1', filename)
    page = omr.run(filename).page[0]
    # TODO(ringw): Use a real templating system (e.g. jinja or mako).
    if page_num > 0:
      print('')
    print(' def test%s_structure(self):' % _sanitized_basename(filename))
    print(' page = engine.OMREngine().run(')
    print(' \'%s\').page[0]' % escaped_filename)
    print(' self.assertEqual(len(page.system), %d)' % len(page.system))
    # Use a distinct index name so the outer loop variable is not shadowed.
    for system_num, system in enumerate(page.system):
      print('')
      print(' self.assertEqual(len(page.system[%d].staff), %d)' %
            (system_num, len(system.staff)))
      print(' self.assertEqual(len(page.system[%d].bar), %d)' %
            (system_num, len(system.bar)))
def _sanitized_basename(filename):
filename, unused_ext = os.path.splitext(os.path.basename(filename))
return re.sub('[^A-z0-9]+', '_', filename)
if __name__ == '__main__':
app.run(main)
================================================
FILE: moonlight/training/clustering/BUILD
================================================
# Description:
# Unsupervised learning pipeline for OMR glyph classification.
package(
default_visibility = ["//moonlight:__subpackages__"],
)
licenses(["notice"]) # Apache 2.0
py_library(
name = "staffline_patches_dofn",
srcs = ["staffline_patches_dofn.py"],
srcs_version = "PY2AND3",
deps = [
# apache-beam dep
"//moonlight/staves:staffline_extractor",
"//moonlight/util:more_iter_tools",
# numpy dep
# six dep
# tensorflow dep
],
)
py_test(
name = "staffline_patches_dofn_test",
size = "medium",
srcs = ["staffline_patches_dofn_test.py"],
data = ["//moonlight/testdata:images"],
srcs_version = "PY2AND3",
deps = [
":staffline_patches_dofn",
# disable_tf2
# absl/testing dep
# apache-beam dep
# tensorflow dep
],
)
py_binary(
name = "staffline_patches_kmeans_pipeline",
srcs = ["staffline_patches_kmeans_pipeline.py"],
srcs_version = "PY2AND3",
deps = [
":staffline_patches_kmeans_pipeline_lib",
# disable_tf2
],
)
py_library(
name = "staffline_patches_kmeans_pipeline_lib",
srcs = ["staffline_patches_kmeans_pipeline.py"],
srcs_version = "PY2AND3",
deps = [
":staffline_patches_dofn",
# absl dep
# apache-beam dep
"//moonlight/pipeline:pipeline_flags",
# tensorflow dep
# tensorflow.contrib.learn dep
],
)
py_test(
name = "staffline_patches_kmeans_pipeline_test",
srcs = ["staffline_patches_kmeans_pipeline_test.py"],
srcs_version = "PY2AND3",
deps = [
":staffline_patches_kmeans_pipeline_lib",
# disable_tf2
# absl dep
# absl/testing dep
# numpy dep
# tensorflow dep
],
)
py_library(
name = "kmeans_labeler_request_handler",
srcs = ["kmeans_labeler_request_handler.py"],
data = ["kmeans_labeler_template.html"],
srcs_version = "PY2AND3",
deps = [
# pillow dep
# mako dep
"//moonlight/protobuf:protobuf_py_pb2",
# numpy dep
# six dep
# tensorflow dep
],
)
py_test(
name = "kmeans_labeler_request_handler_test",
srcs = ["kmeans_labeler_request_handler_test.py"],
srcs_version = "PY2AND3",
deps = [
":kmeans_labeler_request_handler",
# disable_tf2
# absl/testing dep
# mako dep
"//moonlight/protobuf:protobuf_py_pb2",
# numpy dep
# tensorflow dep
],
)
py_binary(
name = "kmeans_labeler",
srcs = ["kmeans_labeler.py"],
srcs_version = "PY2AND3",
deps = [
":kmeans_labeler_request_handler",
# disable_tf2
# pillow dep
# absl dep
# mako dep
"//moonlight/protobuf:protobuf_py_pb2",
# numpy dep
# six dep
# tensorflow dep
],
)
================================================
FILE: moonlight/training/clustering/kmeans_labeler.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serves an HTML page for labeling the k-means clusters.
Takes cluster TFRecords generated by `staffline_patches_kmeans_pipeline.py` and
serves an HTML page containing cluster images. When the HTML form containing
labels is submitted, it (by defaults) overwrites the clusters file, adding the
labels.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import numpy as np
from six.moves import BaseHTTPServer
import tensorflow as tf
from tensorflow.python.lib.io import tf_record
from moonlight.training.clustering import kmeans_labeler_request_handler as handler
FLAGS = flags.FLAGS
flags.DEFINE_string('clusters_path', None, 'Path to the input patch TFRecords')
flags.DEFINE_string(
'output_path', None, 'Path to the output patch TFRecords. Defaults to '
'overwriting clusters_path.')
flags.DEFINE_integer('port', 8000, 'Port for serving the labeler.')
def load_clusters(input_path):
  """Loads TFRecords of Examples representing the k-means clusters.

  Examples are typically the output of `staffline_patches_kmeans_pipeline.py`.

  Args:
    input_path: Path to the TFRecords of Examples.

  Returns:
    A NumPy array of shape (num_clusters, patch_height, patch_width).
  """

  def _to_patch(serialized_example):
    # Each record holds a flattened patch plus its height/width features.
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    feature = example.features.feature
    height = feature['height'].int64_list.value[0]
    width = feature['width'].int64_list.value[0]
    flat_values = feature['features'].float_list.value
    return np.asarray(flat_values).reshape((height, width))

  patches = [
      _to_patch(record) for record in tf_record.tf_record_iterator(input_path)
  ]
  return np.asarray(patches)
def main(_):
  """Starts the labeler HTTP server and blocks forever serving requests."""
  server_address = ('localhost', FLAGS.port)
  clusters = load_clusters(FLAGS.clusters_path)
  # Default to overwriting the input.
  output_path = FLAGS.output_path or FLAGS.clusters_path

  def create_request_handler(*args):
    # The HTTPServer API only lets us pass a handler *class*; close over the
    # clusters and output path instead.
    return handler.LabelerHTTPRequestHandler(clusters, output_path, *args)

  server = BaseHTTPServer.HTTPServer(server_address, create_request_handler)
  # Bug fix: print() was given two arguments ("Listening on port %d...", port)
  # instead of a %-formatted string, so the literal "%d" was printed.
  print('Listening on port %d...' % FLAGS.port)
  server.serve_forever()
if __name__ == '__main__':
tf.app.run()
================================================
FILE: moonlight/training/clustering/kmeans_labeler_request_handler.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The k-means labeler HTTP request handler.
Displays the clusters on an HTML page with dropdowns for each cluster's label.
On POST, updates the output to include the cluster labels.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import cgi
from mako import template as mako_template
import numpy as np
from PIL import Image
import six
from six import moves
from six.moves import BaseHTTPServer
from six.moves import http_client
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import resource_loader
from moonlight.protobuf import musicscore_pb2
class LabelerHTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """The HTTP request handler for the k-means labeler server.

  GET renders the labeler page from the Mako template; POST reads the
  submitted per-cluster labels and writes labeled Examples to `output_path`.

  Attributes:
    clusters: NumPy array of clusters. Shape (num_clusters, patch_height,
      patch_width).
    output_path: Path to write the TFRecords.
  """

  def __init__(self, clusters, output_path, *args):
    # Assign our attributes first: the base __init__ handles the request
    # immediately, and the do_* methods read these attributes.
    self.clusters = clusters
    self.output_path = output_path
    BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args)

  def do_GET(self):
    """Serves the labeler HTML page with previews of every cluster."""
    template_path = resource_loader.get_path_to_datafile(
        'kmeans_labeler_template.html')
    template = mako_template.Template(open(template_path).read())
    page = create_page(self.clusters, template)
    self.send_response(http_client.OK)
    self.send_header('Content-Type', 'text/html; charset=utf-8')
    self.end_headers()
    # NOTE(review): wfile.write expects bytes on Python 3 -- confirm `page`
    # is encoded appropriately for the runtime this server targets.
    self.wfile.write(page)

  def do_POST(self):
    """Reads submitted labels and writes labeled cluster Examples."""
    # NOTE(review): headers.getheader is Python 2 only (Python 3 uses
    # headers.get) -- confirm the intended runtime.
    post_vars = cgi.parse_qs(
        self.rfile.read(int(self.headers.getheader('content-length'))))
    # The form has one dropdown per cluster, named cluster0..clusterN-1.
    labels = [
        post_vars['cluster%d' % i][0]
        for i in moves.xrange(self.clusters.shape[0])
    ]
    examples = create_examples(self.clusters, labels)
    with tf_record.TFRecordWriter(self.output_path) as writer:
      for example in examples:
        writer.write(example.SerializeToString())
    self.send_response(http_client.OK)
    self.end_headers()
    self.wfile.write('Success')  # printed in the labeler alert
def create_page(clusters, template):
  """Renders the labeler HTML.

  Args:
    clusters: NumPy array of clusters.
    template: Mako template for the page.

  Returns:
    The labeler HTML string.
  """
  # Each entry is a tuple (index, preview, is_content).
  annotated = [(index,) + _process_cluster(patch)
               for index, patch in enumerate(clusters)]
  with_content = [entry for entry in annotated if entry[2]]
  without_content = [entry for entry in annotated if not entry[2]]
  return template.render(
      content_clusters=with_content,
      empty_clusters=without_content,
      # Skip the unknown glyph type.
      glyph_types=musicscore_pb2.Glyph.Type.keys()[1:])
def _process_cluster(cluster):
  """Processes a cluster image into a preview and a content flag.

  Args:
    cluster: 2D NumPy array.

  Returns:
    preview: The preview image (PNG encoded in an HTML data URL).
    is_content: boolean; whether the cluster is considered content.
  """
  image_arr = create_highlighted_image(cluster)
  image = Image.fromarray(image_arr)
  buf = six.BytesIO()
  image.save(buf, 'PNG', optimize=True)
  buf.seek(0)
  # Bug fix: b64encode returns bytes; str() of bytes on Python 3 produces
  # "b'...'", corrupting the data URL. Decode to ASCII text instead (on
  # Python 2 this yields an equivalent unicode string).
  preview = ('data:image/png;base64,' +
             base64.b64encode(buf.read()).decode('ascii'))
  # The cluster is likely non-content (does not contain a glyph) if the
  # max standard deviation across all rows or all columns is low. Show those
  # patches at the bottom of the page so that they can still be double checked
  # by hand.
  is_content = min(cluster.std(axis=0).max(), cluster.std(axis=1).max()) > 0.1
  return preview, is_content
def create_highlighted_image(cluster, enlarge_ratio=10):
  """Enlarges the "cluster" image and draws a crosshairs highlight.

  Args:
    cluster: 2D NumPy array.
    enlarge_ratio: Scale of the output image.

  Returns:
    An enlarged and highlighted image as a NumPy array. 3D (height, width, 3).
    A "crosshairs" pattern is drawn which highlights the center pixel of the
    image.
  """
  # Enlarge the image, and repeat on the 3rd axis to get RGB channels.
  image_arr = np.repeat(
      np.repeat(
          np.repeat(
              (cluster * 255).astype(np.uint8)[:, :, None],
              enlarge_ratio,
              axis=0),
          enlarge_ratio,
          axis=1),
      3,
      axis=2)
  # Calculate the vertical and horizontal slice of the image to be highlighted.
  vertical_start = (image_arr.shape[0] - enlarge_ratio) // 2
  vertical_stop = (image_arr.shape[0] + enlarge_ratio) // 2
  vertical_slice = slice(vertical_start, vertical_stop)
  # The horizontal stripe width was previously hard-coded as (-5, +6), which
  # only matches the default enlarge_ratio of 10. Derive it from enlarge_ratio
  # so the crosshair scales with the image; identical to the old behavior for
  # the default ratio.
  half_width = enlarge_ratio // 2
  horiz_slice = slice(image_arr.shape[1] // 2 - half_width,
                      image_arr.shape[1] // 2 + half_width + 1)
  # Highlight the vertical slice of the image to be partially red (halve the
  # intensity and add 127 to the red channel).
  image_arr[vertical_slice] = (
      image_arr[vertical_slice] / 2 + np.array([127., 0., 0.]))
  # Highlight the horizontal slice of the image, avoiding the section that was
  # already highlighted.
  image_arr[:vertical_start, horiz_slice] = (
      image_arr[:vertical_start, horiz_slice] / 2 + np.array([127., 0., 0.]))
  image_arr[vertical_stop:, horiz_slice] = (
      image_arr[vertical_stop:, horiz_slice] / 2 + np.array([127., 0., 0.]))
  return image_arr
def create_examples(clusters, labels):
  """Creates Examples from the clusters and label strings.

  Args:
    clusters: NumPy array of shape (num_clusters, patch_height, patch_width).
    labels: List of string labels, which are names in musicscore_pb2.Glyph.Type.
      Length `num_clusters`.

  Returns:
    A list of Example protos of length `num_clusters`.
  """
  examples = []
  for cluster, label in zip(clusters, labels):
    example = example_pb2.Example()
    feature = example.features.feature
    # Flattened patch pixels plus the original 2D dimensions.
    feature['patch'].float_list.value.extend(cluster.ravel())
    height, width = cluster.shape
    feature['height'].int64_list.value.append(height)
    feature['width'].int64_list.value.append(width)
    # Convert the label name to its numeric enum value.
    feature['label'].int64_list.value.append(
        musicscore_pb2.Glyph.Type.Value(label))
    examples.append(example)
  return examples
================================================
FILE: moonlight/training/clustering/kmeans_labeler_request_handler_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the k-means labeler request handler."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from absl.testing import absltest
from mako import template as mako_template
import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.platform import resource_loader
from moonlight.protobuf import musicscore_pb2
from moonlight.training.clustering import kmeans_labeler_request_handler
class KmeansLabelerRequestHandlerTest(absltest.TestCase):
def testCreatePage(self):
num_clusters = 10
clusters = np.random.random((num_clusters, 12, 14))
template_src = resource_loader.get_path_to_datafile(
'kmeans_labeler_template.html')
template = mako_template.Template(open(template_src).read())
html = kmeans_labeler_request_handler.create_page(clusters, template)
self.assertEqual(len(re.findall('
## cluster is a tuple (index, preview, is_content).
% for glyph_type in glyph_types:
${ glyph_type }
% endfor
%def>
Magenta OMR Patch Labeler
================================================
FILE: moonlight/training/clustering/staffline_patches_dofn.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracts non-empty patches of extracted stafflines.
Extracts vertical slices of the image where glyphs are expected
(see `staffline_extractor.py`), and takes horizontal windows of the slice which
will be clustered. Some patches will have a glyph roughly in their center, and
the corresponding cluster centroids will be labeled as such.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import apache_beam as beam
from apache_beam import metrics
from moonlight.staves import staffline_extractor
from moonlight.util import more_iter_tools
import numpy as np
from six.moves import filter
import tensorflow as tf
def _filter_patch(patch, min_num_dark_pixels=10):
unused_patch_name, patch = patch
return np.greater_equal(np.sum(np.less(patch, 0.5)), min_num_dark_pixels)
class StafflinePatchesDoFn(beam.DoFn):
  """Runs the staffline patches graph.

  For each input PNG path, extracts fixed-size staffline patches with a
  TensorFlow graph and yields each kept patch as a tf.train.Example with
  'name', 'features', 'height', and 'width' features.
  """

  def __init__(self, patch_height, patch_width, num_stafflines, timeout_ms,
               max_patches_per_page):
    # Geometry of each extracted patch, in pixels.
    self.patch_height = patch_height
    self.patch_width = patch_width
    # NOTE(review): num_stafflines is stored but never forwarded to the
    # StafflinePatchExtractor in start_bundle -- confirm whether it should be.
    self.num_stafflines = num_stafflines
    # Per-page TF run timeout, forwarded as tf.RunOptions.timeout_in_ms.
    self.timeout_ms = timeout_ms
    # If positive, pages with more patches are subsampled down to this count.
    self.max_patches_per_page = max_patches_per_page
    # Beam metrics for monitoring page/patch throughput and failures.
    self.total_pages_counter = metrics.Metrics.counter(self.__class__,
                                                       'total_pages')
    self.failed_pages_counter = metrics.Metrics.counter(self.__class__,
                                                        'failed_pages')
    self.successful_pages_counter = metrics.Metrics.counter(
        self.__class__, 'successful_pages')
    self.empty_pages_counter = metrics.Metrics.counter(self.__class__,
                                                       'empty_pages')
    self.total_patches_counter = metrics.Metrics.counter(
        self.__class__, 'total_patches')
    self.emitted_patches_counter = metrics.Metrics.counter(
        self.__class__, 'emitted_patches')

  def start_bundle(self):
    # Build the extraction graph and session once per bundle, not per element.
    self.extractor = staffline_extractor.StafflinePatchExtractor(
        patch_height=self.patch_height,
        patch_width=self.patch_width,
        run_options=tf.RunOptions(timeout_in_ms=self.timeout_ms))
    self.session = tf.Session(graph=self.extractor.graph)

  def process(self, png_path):
    """Yields tf.train.Examples for the kept patches of one page image."""
    self.total_pages_counter.inc()
    try:
      with self.session.as_default():
        patches_iter = self.extractor.page_patch_iterator(png_path)
    # pylint: disable=broad-except
    except Exception:
      # Best-effort: skip unreadable/failed scores instead of failing the
      # whole pipeline. NOTE(review): if page_patch_iterator is lazy, failures
      # during iteration happen below, outside this try -- confirm.
      logging.exception('Skipping failed music score (%s)', png_path)
      self.failed_pages_counter.inc()
      return
    # Drop patches that are too light to plausibly contain a glyph.
    patches_iter = filter(_filter_patch, patches_iter)
    if 0 < self.max_patches_per_page:
      # Subsample patches.
      patches = more_iter_tools.iter_sample(patches_iter,
                                            self.max_patches_per_page)
    else:
      patches = list(patches_iter)
    if not patches:
      self.empty_pages_counter.inc()
    # NOTE(review): this counts patches after filtering/subsampling, so
    # total_patches and emitted_patches track the same quantity -- confirm
    # whether total_patches was meant to be counted earlier.
    self.total_patches_counter.inc(len(patches))
    # Serialize each patch as an Example.
    for patch_name, patch in patches:
      example = tf.train.Example()
      example.features.feature['name'].bytes_list.value.append(
          patch_name.encode('utf-8'))
      example.features.feature['features'].float_list.value.extend(
          patch.ravel())
      example.features.feature['height'].int64_list.value.append(patch.shape[0])
      example.features.feature['width'].int64_list.value.append(patch.shape[1])
      yield example
    self.successful_pages_counter.inc()
    # Patches are sub-sampled by this point.
    self.emitted_patches_counter.inc(len(patches))

  def finish_bundle(self):
    # Release the TF session and drop graph references between bundles.
    self.session.close()
    del self.extractor
    del self.session
================================================
FILE: moonlight/training/clustering/staffline_patches_dofn_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the staffline patches DoFn graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tempfile
from absl.testing import absltest
import apache_beam as beam
import tensorflow as tf
from tensorflow.python.lib.io import tf_record
from moonlight.staves import staffline_extractor
from moonlight.training.clustering import staffline_patches_dofn
# Dimensions of each extracted staffline patch, in pixels.
PATCH_HEIGHT = 9
PATCH_WIDTH = 7
# Number of stafflines to extract per page.
NUM_STAFFLINES = 9
# Per-page processing timeout (one minute).
TIMEOUT_MS = 60000
# Upper bound on patches sampled from a single page.
MAX_PATCHES_PER_PAGE = 10
class StafflinePatchesDoFnTest(absltest.TestCase):

  def testPipeline_corpusImage(self):
    """End-to-end: every DoFn output patch matches direct extraction."""
    filename = os.path.join(tf.resource_loader.get_data_files_path(),
                            '../../testdata/IMSLP00747-000.png')
    with tempfile.NamedTemporaryFile() as output_examples:
      # Run the pipeline to get the staffline patches.
      with beam.Pipeline() as pipeline:
        dofn = staffline_patches_dofn.StafflinePatchesDoFn(
            PATCH_HEIGHT, PATCH_WIDTH, NUM_STAFFLINES, TIMEOUT_MS,
            MAX_PATCHES_PER_PAGE)
        # pylint: disable=expression-not-assigned
        (pipeline | beam.transforms.Create([filename])
         | beam.transforms.ParDo(dofn) | beam.io.WriteToTFRecord(
             output_examples.name,
             beam.coders.ProtoCoder(tf.train.Example),
             shard_name_template=''))
      # Get the staffline images from a local TensorFlow session.
      # NOTE(review): the extractor here is constructed with different
      # positional arguments (num sections first) than the DoFn uses --
      # confirm against StafflinePatchExtractor's signature.
      extractor = staffline_extractor.StafflinePatchExtractor(
          staffline_extractor.DEFAULT_NUM_SECTIONS, PATCH_HEIGHT, PATCH_WIDTH)
      with tf.Session(graph=extractor.graph):
        expected_patches = [
            tuple(patch.ravel())
            for unused_key, patch in extractor.page_patch_iterator(filename)
        ]
      # Every emitted Example's pixel tuple must appear in the expected set.
      for example_bytes in tf_record.tf_record_iterator(output_examples.name):
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        patch_pixels = tuple(
            example.features.feature['features'].float_list.value)
        if patch_pixels not in expected_patches:
          self.fail('Missing patch {}'.format(patch_pixels))
if __name__ == '__main__':
  # Script entry point: run the tests.
  absltest.main()
================================================
FILE: moonlight/training/clustering/staffline_patches_kmeans_pipeline.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Glyph classification unsupervised pipeline.
Extracts patches from stafflines, and runs k-means on the patches. Each patch
will be labeled and used for k-nearest-neighbors classification.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import random
import shutil
import tempfile
from absl import flags
import apache_beam as beam
from apache_beam.transforms import combiners
from moonlight.pipeline import pipeline_flags
from moonlight.training.clustering import staffline_patches_dofn
import tensorflow as tf
from tensorflow.contrib import learn as contrib_learn
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.python.lib.io import file_io
from tensorflow.python.lib.io import tf_record
FLAGS = flags.FLAGS

# Input/output locations.
flags.DEFINE_multi_string('music_pattern', [],
                          'Pattern for the input music score PNGs.')
flags.DEFINE_string('output_path', None, 'Path to the output TFRecords.')
# Patch extraction geometry.
flags.DEFINE_integer('patch_height', 18,
                     'The normalized height of a staffline.')
flags.DEFINE_integer('patch_width', 15,
                     'The width of a horizontal patch of a staffline.')
flags.DEFINE_integer('num_stafflines', 19,
                     'The number of stafflines to extract.')
# Subsampling controls; 0 disables the corresponding subsampling.
flags.DEFINE_integer('num_pages', 0, 'Subsample the pages to run on.')
flags.DEFINE_integer('num_outputs', 0, 'Number of output patches.')
flags.DEFINE_integer('max_patches_per_page', 10,
                     'Sample patches per page if above this amount.')
flags.DEFINE_integer('timeout_ms', 600000, 'Timeout for processing a page.')
# k-means hyperparameters.
flags.DEFINE_integer('kmeans_num_clusters', 1000, 'Number of k-means clusters.')
flags.DEFINE_integer('kmeans_batch_size', 10000,
                     'Batch size for mini-batch k-means.')
flags.DEFINE_integer('kmeans_num_steps', 100,
                     'Number of k-means training steps.')
def train_kmeans(patch_file_pattern,
                 num_clusters,
                 batch_size,
                 train_steps,
                 min_eval_frequency=None):
  """Runs TensorFlow K-Means over TFRecords.

  Args:
    patch_file_pattern: Pattern that matches TFRecord file(s) holding Examples
      with image patches.
    num_clusters: Number of output clusters.
    batch_size: Size of a k-means minibatch.
    train_steps: Number of steps for k-means training.
    min_eval_frequency: The minimum number of steps between evaluations. Of
      course, evaluation does not occur if no new snapshot is available, hence,
      this is the minimum. If 0, the evaluation will only happen after
      training. If None, defaults to 1. To avoid checking for new checkpoints
      too frequent, the interval is further limited to be at least
      check_interval_secs between checks. See
      third_party/tensorflow/contrib/learn/python/learn/experiment.py for
      details.

  Returns:
    A NumPy array of shape (num_clusters, patch_height * patch_width). The
    cluster centers.
  """

  def input_fn():
    """The tf.learn input_fn.

    Returns:
      features, a float32 tensor of shape
          (batch_size, patch_height * patch_width).
      None for labels (not applicable to k-means).
    """
    examples = contrib_learn.read_batch_examples(
        patch_file_pattern,
        batch_size,
        tf.TFRecordReader,
        queue_capacity=batch_size * 2)
    # Each Example stores one flattened float patch under the 'features' key.
    features = tf.parse_example(
        examples, {
            'features':
                tf.FixedLenFeature(FLAGS.patch_height * FLAGS.patch_width,
                                   tf.float32)
        })['features']
    return features, None  # no labels

  def experiment_fn(run_config, unused_hparams):
    """The tf.learn experiment_fn.

    Args:
      run_config: The run config to be passed to the KMeansClustering.
      unused_hparams: Hyperparameters; not applicable.

    Returns:
      A tf.contrib.learn.Experiment.
    """
    kmeans = contrib_learn.KMeansClustering(
        num_clusters=num_clusters, config=run_config)
    return contrib_learn.Experiment(
        estimator=kmeans,
        train_steps=train_steps,
        train_input_fn=input_fn,
        eval_steps=1,
        eval_input_fn=input_fn,
        min_eval_frequency=min_eval_frequency)

  # Train into a throwaway model directory, then read the cluster centers
  # back out of the final checkpoint.
  output_dir = tempfile.mkdtemp(prefix='staffline_patches_kmeans')
  try:
    learn_runner.run(
        experiment_fn, run_config=contrib_learn.RunConfig(model_dir=output_dir))
    num_features = FLAGS.patch_height * FLAGS.patch_width
    # Declare a variable with the name KMeansClustering stores its centers
    # under ('clusters') so the Saver can restore it from the checkpoint.
    # NOTE(review): relies on tf.contrib KMeansClustering's internal variable
    # name and checkpoint step numbering -- confirm against the pinned TF
    # version.
    clusters_t = tf.Variable(
        tf.zeros((num_clusters, num_features)),  # Dummy init op
        name='clusters')
    with tf.Session() as sess:
      tf.train.Saver(var_list=[clusters_t]).restore(
          sess, os.path.join(output_dir, 'model.ckpt-%d' % train_steps))
      return clusters_t.eval()
  finally:
    # Always clean up the temporary model directory.
    shutil.rmtree(output_dir)
def main(_):
  """Extracts patches with Beam, runs k-means, and writes the centroids."""
  tf.logging.info('Building the pipeline...')
  # Temporary directory for the intermediate patch TFRecords.
  records_dir = tempfile.mkdtemp(prefix='staffline_kmeans')
  try:
    patch_file_prefix = os.path.join(records_dir, 'patches')
    with pipeline_flags.create_pipeline() as pipeline:
      filenames = file_io.get_matching_files(FLAGS.music_pattern)
      assert filenames, 'Must have matched some filenames'
      # Optionally subsample the pages to process.
      if 0 < FLAGS.num_pages < len(filenames):
        filenames = random.sample(filenames, FLAGS.num_pages)
      filenames = pipeline | beam.transforms.Create(filenames)
      patches = filenames | beam.ParDo(
          staffline_patches_dofn.StafflinePatchesDoFn(
              patch_height=FLAGS.patch_height,
              patch_width=FLAGS.patch_width,
              num_stafflines=FLAGS.num_stafflines,
              timeout_ms=FLAGS.timeout_ms,
              max_patches_per_page=FLAGS.max_patches_per_page))
      # Optionally cap the total number of output patches.
      if FLAGS.num_outputs:
        patches |= combiners.Sample.FixedSizeGlobally(FLAGS.num_outputs)
      patches |= beam.io.WriteToTFRecord(
          patch_file_prefix, beam.coders.ProtoCoder(tf.train.Example))
      tf.logging.info('Running the pipeline...')
    # The pipeline has run to completion once the context manager exits.
    tf.logging.info('Running k-means...')
    patch_files = file_io.get_matching_files(patch_file_prefix + '*')
    clusters = train_kmeans(patch_files, FLAGS.kmeans_num_clusters,
                            FLAGS.kmeans_batch_size, FLAGS.kmeans_num_steps)
    tf.logging.info('Writing the centroids...')
    # Each centroid is written with the same schema as the input patches.
    with tf_record.TFRecordWriter(FLAGS.output_path) as writer:
      for cluster in clusters:
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(cluster)
        example.features.feature['height'].int64_list.value.append(
            FLAGS.patch_height)
        example.features.feature['width'].int64_list.value.append(
            FLAGS.patch_width)
        writer.write(example.SerializeToString())
    tf.logging.info('Done!')
  finally:
    # Clean up the intermediate records even on failure.
    shutil.rmtree(records_dir)
if __name__ == '__main__':
  # Script entry point: parse flags and run the pipeline.
  tf.app.run(main)
================================================
FILE: moonlight/training/clustering/staffline_patches_kmeans_pipeline_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for staffline patches k-means."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import flags
from absl.testing import absltest
import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.python.lib.io import tf_record
from moonlight.training.clustering import staffline_patches_kmeans_pipeline
FLAGS = flags.FLAGS

# Small k-means configuration to keep the test fast.
NUM_CLUSTERS = 20
BATCH_SIZE = 100
TRAIN_STEPS = 5
class StafflinePatchesKmeansPipelineTest(absltest.TestCase):

  def testKmeans(self):
    """train_kmeans returns centroids of shape (NUM_CLUSTERS, num_features)."""
    num_features = FLAGS.patch_height * FLAGS.patch_width
    # Random patches stand in for real staffline data.
    dummy_data = np.random.random((500, num_features))
    with tempfile.NamedTemporaryFile(mode='r') as patches_file:
      # Write the dummy patches as Examples to the temp file's path.
      with tf_record.TFRecordWriter(patches_file.name) as patches_writer:
        for patch in dummy_data:
          example = example_pb2.Example()
          example.features.feature['features'].float_list.value.extend(patch)
          patches_writer.write(example.SerializeToString())
      clusters = staffline_patches_kmeans_pipeline.train_kmeans(
          patches_file.name,
          NUM_CLUSTERS,
          BATCH_SIZE,
          TRAIN_STEPS,
          min_eval_frequency=0)
      self.assertEqual(clusters.shape, (NUM_CLUSTERS, num_features))
if __name__ == '__main__':
  # Script entry point: run the tests.
  absltest.main()
================================================
FILE: moonlight/training/generation/BUILD
================================================
# Description:
# Generated training data for OMR.
package(default_visibility = ["//moonlight:__subpackages__"])
licenses(["notice"]) # Apache 2.0
py_library(
name = "generation",
srcs = ["generation.py"],
data = ["vexflow_generator.js"],
srcs_version = "PY2AND3",
deps = [
"@com_google_protobuf//:protobuf_python",
# apache-beam dep
"//moonlight:engine",
"//moonlight/protobuf:protobuf_py_pb2",
"//moonlight/staves:staffline_distance",
"//moonlight/staves:staffline_extractor",
# numpy dep
# tensorflow dep
],
)
py_test(
name = "generation_test",
srcs = ["generation_test.py"],
srcs_version = "PY2AND3",
tags = ["manual"],
deps = [
":generation",
# disable_tf2
"//moonlight/staves:staffline_extractor",
# tensorflow dep
],
)
py_library(
name = "image_noise",
srcs = ["image_noise.py"],
srcs_version = "PY2AND3",
deps = [
# tensorflow dep
# tensorflow.contrib.image py dep
],
)
py_binary(
name = "vexflow_generator_pipeline",
srcs = ["vexflow_generator_pipeline.py"],
srcs_version = "PY2AND3",
deps = [
":generation",
":image_noise",
# disable_tf2
# absl dep
# apache-beam dep
"//moonlight/pipeline:pipeline_flags",
# tensorflow dep
],
)
================================================
FILE: moonlight/training/generation/generation.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VexFlow labeled data generation.
Wraps the node.js generator, which generates a random measure of music as SVG,
and the ground truth glyphs present in the image as a `Page` message.
Each invocation generates a batch of images. There is a tradeoff between the
startup time of node.js for each invocation, and keeping the output size small
enough to pipe into Python.
The final outputs are positive and negative example patches. Positive examples
are centered on an outputted glyph, and have that glyph's type. Negative
examples are at least a few pixels away from any glyph, and have type NONE.
Since negative examples could be a few pixels away from a glyph, we get negative
examples that overlap with partial glyph(s), but are centered too far away from
a glyph to be considered a positive example. Currently, every single glyph
results in a single positive example, and negative examples are randomly
sampled.
All glyphs are emitted to RecordIO, where they are outputted in a single
collection for training. We currently do not store the entire generated image
anywhere. This could be added later in order to try other classification
approaches.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os.path
import random
import subprocess
import sys
import apache_beam as beam
from apache_beam.metrics import Metrics
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from moonlight import engine
from moonlight.protobuf import musicscore_pb2
from moonlight.staves import staffline_distance
from moonlight.staves import staffline_extractor
# Every image is expected to contain at least 3 glyphs.
POSITIVE_EXAMPLES_PER_IMAGE = 3
def _normalize_path(filename):
"""Normalizes a relative path to a command to spawn.
Args:
filename: String; relative or absolute path.
Returns:
The normalized path. This is necessary because in our use case,
vexflow_generator_pipeline will live in a different directory from
vexflow_generator, and there are symlinks to both directories in the same
parent directory. Without normalization, `..` would reference the parent of
the actual directory that was symlinked. With normalization, it references
the directory that contains the symlink to the working directory.
"""
if filename.startswith('/'):
return filename
else:
return os.path.normpath(
os.path.join(os.path.dirname(sys.argv[0]), filename))
class PageGenerationDoFn(beam.DoFn):
  """Generates the PNG images and ground truth for each batch.

  Takes in a batch number, and outputs a tuple of PNG contents (bytes) and the
  labeled staff (Staff message).
  """

  def __init__(self, num_pages_per_batch, vexflow_generator_command,
               svg_to_png_command):
    # Number of pages each batch number expands to.
    self.num_pages_per_batch = num_pages_per_batch
    # argv lists for the two external tools; element 0 may be a relative path.
    self.vexflow_generator_command = vexflow_generator_command
    self.svg_to_png_command = svg_to_png_command

  def process(self, batch_num):
    """Yields (png_bytes, serialized Staff) for each page of the batch."""
    for page in self.get_pages_for_batch(batch_num, self.num_pages_per_batch):
      staff = musicscore_pb2.Staff()
      text_format.Parse(page['page'], staff)
      # TODO(ringw): Fix the internal proto pickling issue so that we don't
      # have to serialize the staff here.
      yield self._svg_to_png(page['svg']), staff.SerializeToString()

  def get_pages_for_batch(self, batch_num, num_pages_per_batch):
    """Generates the music score pages in a single batch.

    The generator takes in a seed for the RNG for each page, and outputs all
    pages at once. The seeds for all batches are consecutive for determinism,
    starting from 0, but each seed to the Mersenne Twister RNG should result in
    completely different output.

    Args:
      batch_num: The index of the batch to output.
      num_pages_per_batch: The number of pages to generate in each batch.

    Returns:
      A list of dicts holding `svg` (XML text) and `page` (text-format
      `tensorflow.moonlight.Staff` proto).
    """
    return self.get_pages(
        range(batch_num * num_pages_per_batch,
              (batch_num + 1) * num_pages_per_batch))

  def get_pages(self, seeds):
    """Invokes the node.js generator once for all seeds; parses its JSON."""
    vexflow_generator_command = list(self.vexflow_generator_command)
    # If vexflow_generator_command is relative, it is relative to the pipeline
    # binary.
    vexflow_generator_command[0] = _normalize_path(vexflow_generator_command[0])
    seeds = ','.join(map(str, seeds))
    return json.loads(
        subprocess.check_output(vexflow_generator_command +
                                ['--random_seeds=' + seeds]))

  def _svg_to_png(self, svg):
    """Pipes SVG text through the converter command; returns PNG bytes."""
    svg_to_png_command = list(self.svg_to_png_command)
    svg_to_png_command[0] = _normalize_path(svg_to_png_command[0])
    popen = subprocess.Popen(
        svg_to_png_command,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    # NOTE(review): `svg` comes from json.loads and is text; under Python 3,
    # communicate() on a binary pipe expects bytes -- confirm encoding here.
    stdout, stderr = popen.communicate(input=svg)
    if popen.returncode != 0:
      raise ValueError('convert failed with status %d\nstderr:\n%s' %
                       (popen.returncode, stderr))
    return stdout
class PatchExampleDoFn(beam.DoFn):
  """Extracts labeled patches from generated VexFlow music scores."""

  def __init__(self,
               negative_example_distance,
               patch_width,
               negative_to_positive_example_ratio,
               noise_fn=lambda x: x):
    # Minimum x distance (in staffline pixels) between a glyph center and any
    # sampled negative example.
    self.negative_example_distance = negative_example_distance
    self.patch_width = patch_width
    # How many negative examples to sample per positive example.
    self.negative_to_positive_example_ratio = negative_to_positive_example_ratio
    # Optional image-noise transform applied to the float image tensor.
    self.noise_fn = noise_fn
    self.patch_counter = Metrics.counter(self.__class__, 'num_patches')

  def start_bundle(self):
    # TODO(ringw): Expose a cleaner way to set this value.
    # The image is too small for the default min staffline distance score.
    # pylint: disable=protected-access
    staffline_distance._MIN_STAFFLINE_DISTANCE_SCORE = 100
    self.omr = engine.OMREngine()

  def process(self, item):
    """Yields positive and negative patch Examples for one generated page."""
    png_contents, staff_message = item
    staff_message = musicscore_pb2.Staff.FromString(staff_message)
    with tf.Session(graph=self.omr.graph) as sess:
      # Load the image, then feed it in to apply noise.
      # Randomly rotate the image and apply noise, then dump it back out as a
      # PNG.
      # TODO(ringw): Expose a way to pass in the image contents to the main
      # OMR TF graph.
      img = tf.to_float(tf.image.decode_png(png_contents))
      # Collapse the RGB channels, if any. No-op for a monochrome PNG.
      img = tf.reduce_mean(img[:, :, :3], axis=2)[:, :, None]
      # Fix the stafflines being #999.
      img = tf.clip_by_value(img * 2. - 255., 0., 255.)
      img = self.noise_fn(img)
      # Get a 2D uint8 image array for OMR.
      noisy_image = sess.run(
          tf.cast(tf.clip_by_value(img, 0, 255)[:, :, 0], tf.uint8))
      # Run OMR staffline extraction and staffline distance estimation. The
      # stafflines are used to get patches from the generated image.
      stafflines, image_staffline_distance = sess.run(
          [
              self.omr.glyph_classifier.staffline_extractor.extract_staves(),
              self.omr.structure.staff_detector.staffline_distance[0]
          ],
          feed_dict={self.omr.image: noisy_image})
    if stafflines.shape[0] != 1:
      raise ValueError('Image should have one detected staff, got shape: ' +
                       str(stafflines.shape))
    positive_example_count = 0
    # Boolean mask of (staffline position, x) cells still eligible to be
    # sampled as negative examples.
    # NOTE(review): np.bool is removed in NumPy >= 1.24; this code targets an
    # older NumPy -- confirm pinned version before upgrading.
    negative_example_whitelist = np.ones(
        (stafflines.shape[staffline_extractor.Axes.POSITION],
         stafflines.shape[staffline_extractor.Axes.X]), np.bool)
    # Blacklist xs where the patch would overlap with either end.
    negative_example_overlap_from_end = max(self.negative_example_distance,
                                            self.patch_width // 2)
    negative_example_whitelist[:, :negative_example_overlap_from_end] = False
    negative_example_whitelist[:,
                               -negative_example_overlap_from_end - 1:] = False
    all_positive_examples = []
    for glyph in staff_message.glyph:
      staffline = staffline_extractor.get_staffline(glyph.y_position,
                                                    stafflines[0])
      # Scale the glyph's page x coordinate into staffline-image coordinates.
      glyph_x = int(
          round(glyph.x *
                self.omr.glyph_classifier.staffline_extractor.target_height /
                (image_staffline_distance * self.omr.glyph_classifier
                 .staffline_extractor.staffline_distance_multiple)))
      example = self._create_example(staffline, glyph_x, glyph.type)
      if example:
        staffline_index = staffline_extractor.y_position_to_index(
            glyph.y_position,
            stafflines.shape[staffline_extractor.Axes.POSITION])
        # Blacklist the area adjacent to the glyph, even if it is not selected
        # as a positive example below.
        negative_example_whitelist[staffline_index, glyph_x -
                                   self.negative_example_distance + 1:glyph_x +
                                   self.negative_example_distance] = False
        all_positive_examples.append(example)
        positive_example_count += 1
    # NOTE(review): random.sample raises ValueError if fewer than
    # POSITIVE_EXAMPLES_PER_IMAGE positives were found -- confirm every
    # generated page is guaranteed at least that many glyphs in range.
    for example in random.sample(all_positive_examples,
                                 POSITIVE_EXAMPLES_PER_IMAGE):
      yield example
      self.patch_counter.inc()
    # Sample negative examples uniformly from the remaining whitelisted cells.
    negative_example_staffline, negative_example_x = np.where(
        negative_example_whitelist)
    negative_example_inds = np.random.choice(
        len(negative_example_staffline),
        int(positive_example_count * self.negative_to_positive_example_ratio))
    negative_example_staffline = negative_example_staffline[
        negative_example_inds]
    negative_example_x = negative_example_x[negative_example_inds]
    for staffline, x in zip(negative_example_staffline, negative_example_x):
      example = self._create_example(stafflines[0, staffline], x,
                                     musicscore_pb2.Glyph.NONE)
      assert example, 'Negative example xs should always be in range'
      yield example
      self.patch_counter.inc()

  def _create_example(self, staffline, x, label):
    """Builds a patch Example centered on x, or None if out of range."""
    start_x = x - self.patch_width // 2
    limit_x = x + self.patch_width // 2 + 1
    assert limit_x - start_x == self.patch_width
    # x is the last axis of staffline
    # NOTE(review): `limit_x < shape[-1]` rejects a patch flush with the right
    # edge (slicing would allow limit_x == shape[-1]) -- confirm intentional.
    if 0 <= start_x <= limit_x < staffline.shape[-1]:
      patch = staffline[:, start_x:limit_x]
      example = tf.train.Example()
      example.features.feature['patch'].float_list.value.extend(patch.ravel())
      example.features.feature['label'].int64_list.value.append(label)
      example.features.feature['height'].int64_list.value.append(patch.shape[0])
      example.features.feature['width'].int64_list.value.append(patch.shape[1])
      return example
    else:
      return None
================================================
FILE: moonlight/training/generation/generation_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for labeled data generation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
from moonlight.staves import staffline_extractor
from moonlight.training.generation import generation
# Width, in pixels, of each generated example patch.
PATCH_WIDTH = 15
class GenerationTest(tf.test.TestCase):

  def testDoFn(self):
    """Chains the page-generation and patch-extraction DoFns end to end."""
    page_gen = generation.PageGenerationDoFn(
        num_pages_per_batch=1,
        vexflow_generator_command=[
            os.path.join(tf.resource_loader.get_data_files_path(),
                         'vexflow_generator')
        ],
        svg_to_png_command=['/usr/bin/env', 'convert', 'svg:-', 'png:-'])
    # start_bundle/finish_bundle come from the beam.DoFn base class.
    page_gen.start_bundle()
    patch_examples = generation.PatchExampleDoFn(
        negative_example_distance=5,
        patch_width=PATCH_WIDTH,
        negative_to_positive_example_ratio=1.0)
    patch_examples.start_bundle()
    # Feed each generated page directly into the patch extractor.
    examples = [
        example for page in page_gen.process(0)
        for example in patch_examples.process(page)
    ]
    page_gen.finish_bundle()
    patch_examples.finish_bundle()
    self.assertGreater(len(examples), 4)
    # Every patch must be a full-height staffline slice of PATCH_WIDTH pixels.
    for example in examples:
      self.assertEqual(
          len(example.features.feature['patch'].float_list.value),
          PATCH_WIDTH * staffline_extractor.DEFAULT_TARGET_HEIGHT)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: moonlight/training/generation/image_noise.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Applies noise to an image for generating training data.
All noise assumes a monochrome image with white (255) as background.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from tensorflow.contrib import image as contrib_image
def placeholder_image():
  """Returns a 2D uint8 placeholder for feeding in a monochrome image."""
  image = tf.placeholder(tf.uint8, shape=(None, None), name='placeholder_image')
  return image
def random_rotation(image, angle=math.pi / 180):
  """Rotates the image by a small angle sampled uniformly from [-angle, angle].

  The image is inverted before rotating so that the area exposed by the
  rotation maps back to white (255), the background value.
  """
  inverted = 255. - tf.to_float(image)
  random_angle = tf.random_uniform((), -angle, angle)
  rotated = contrib_image.rotate(inverted, random_angle,
                                 interpolation='BILINEAR')
  return 255. - rotated
def gaussian_noise(image, stddev=5):
  """Adds zero-mean Gaussian noise with the given stddev to every pixel."""
  noise = tf.random_normal(tf.shape(image), stddev=stddev)
  return image + noise
================================================
FILE: moonlight/training/generation/vexflow_generator.js
================================================
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @fileoverview VexFlow random labeled data generator.
*
* Outputs a JSON list of dicts with "svg" (XML text SVG image) and "page" (text
* format Staff proto holding the glyph coordinates).
*/
const ArgumentParser = require('argparse').ArgumentParser;
const jsdom = require('jsdom');
const Random = require('random-js');
const Vex = require('vexflow');
const VF = Vex.Flow;
// Command-line flag parsing. --random_seeds is validated further below,
// where it is split and converted to integers.
const parser = new ArgumentParser();
parser.addArgument(
    ['--random_seeds'],
    {help: 'Generate a labeled image for each comma-separated random seed.'});
const args = parser.parseArgs();
/**
 * @param {?object} value Any value
 * @param {string=} opt_message The message, if value is null or undefined
 * @return {!object} The non-null value
 * @throws {Error} if value is null or undefined
 */
function checkNotNull(value, opt_message) {
  // `value == null` also matches undefined.
  if (value == null) {
    // BUG FIX: `ValueError` is not defined in JavaScript; the original code
    // threw a ReferenceError instead of the intended error with opt_message.
    throw new Error(opt_message);
  }
  return value;
}
/**
 * VexFlow line numbers start at the first ledger line below the staff, and
 * increment by a half for each note. OMR y positions start at the third staff
 * line, and increment by 1 for each note.
 * @param {number} line The VexFlow line number.
 * @return {number} An integer y position.
 */
function vexflowLineToOMR(line) {
  // Re-center on the third staff line (VexFlow line 3), then convert from
  // whole-line units to half-line (note position) units.
  const linesFromCenter = line - 3;
  return linesFromCenter * 2;
}
/**
 * Converts an absolute Y coordinate on a staff to an OMR y position.
 * @param {!Vex.Flow.Stave} stave The VexFlow stave.
 * @param {number} y The y coordinate, in pixels.
 * @return {number} An OMR y position.
 */
function absoluteYToOMR(stave, y) {
  // Distance above the center (third) staff line, measured in half
  // staffline-spacings and rounded to the nearest integer position.
  const centerY = stave.getYForLine(2);
  const spacing = stave.space(1);
  return Math.round((centerY - y) * 2 / spacing);
}
// Mutable per-page state, populated by the patched VexFlow draw methods
// below and serialized into a page proto after each vf.draw() call.
let allGlyphs, staffCenterLine, stafflineDistance;

/** Resets the dumped VexFlow information before starting a new page. */
function resetPageState() {
  allGlyphs = [];
  staffCenterLine = null;
  stafflineDistance = null;
}
resetPageState();
// Maps a VexFlow clef type to the OMR y position of the clef glyph's center.
const CLEF_LINE_FOR_OMR = {
  // Treble or "G" clef is centered 2 spaces below the center line (on G).
  'treble': -2,
  // Bass or "F" clef is centered 2 spaces above the center line (on F).
  'bass': +2
};

// Keep a handle on the real implementation so the patch can delegate to it.
const drawClef = checkNotNull(VF.Clef.prototype.draw, 'Clef.draw');

/** Draws the clef and dumps its position to allGlyphs. */
VF.Clef.prototype.draw = function() {
  // Only treble and bass clefs are labeled; other clef types are drawn but
  // produce no glyph record.
  if (this.type == 'treble' || this.type == 'bass') {
    // The glyph center is halfway across the clef's width.
    const x = Math.round(this.getX() + this.getWidth() / 2);
    allGlyphs.push(`glyph {
type: CLEF_${this.type.toUpperCase()}
x: ${x}
y_position: ${CLEF_LINE_FOR_OMR[this.type]}
}`);
  }
  drawClef.apply(this, arguments);
};
// Keep a handle on the real implementation so the patch can delegate to it.
const drawStave = checkNotNull(VF.Stave.prototype.draw, 'Stave.draw');

/** Dumps the staff information. */
VF.Stave.prototype.draw = function() {
  // Record the staffline spacing and both endpoints of the center (third)
  // staff line for the page proto.
  stafflineDistance = this.space(1);
  const y = this.getYForLine(2);
  const x0 = this.getX();
  const x1 = this.getX() + this.getWidth();
  staffCenterLine = `center_line {
x: ${x0}
y: ${y}
}
center_line {
x: ${x1}
y: ${y}
}
`;
  drawStave.apply(this, arguments);
};
// Keep a handle on the real implementation so the patch can delegate to it.
const drawNotehead = checkNotNull(VF.NoteHead.prototype.draw, 'Notehead.draw');

/** Draws the notehead and dumps its position to allGlyphs. */
VF.NoteHead.prototype.draw = function() {
  // The notehead x seems to be the left end.
  const x = Math.round(this.getAbsoluteX() + this.getWidth() / 2);
  const y_position = vexflowLineToOMR(this.getLine());
  allGlyphs.push(`glyph {
# TODO(ringw): NOTEHEAD_FILLED vs NOTEHEAD_EMPTY.
type: NOTEHEAD_FILLED
x: ${x}
y_position: ${y_position}
}`);
  drawNotehead.apply(this, arguments);
};
// VexFlow accidental codes mapped to glyph proto type names.
// NOTE(review): '##' and 'bb' can be generated by PROB_MODIFIERS but have no
// entry here, so double accidentals are drawn unlabeled — confirm intended.
const ACCIDENTAL_TYPES = {
  'b': 'FLAT',
  '#': 'SHARP',
  'n': 'NATURAL'
};

// Keep a handle on the real implementation so the patch can delegate to it.
const drawAccidental =
    checkNotNull(VF.Accidental.prototype.draw, 'Accidental.draw');

/** Draws the accidental and dumps its position to allGlyphs. */
VF.Accidental.prototype.draw = function() {
  if (this.type in ACCIDENTAL_TYPES) {
    const note_start = this.note.getModifierStartXY(this.position, this.index);
    // The modifier x (note_start.x + this.x_shift) seems to be the right end
    // of the glyph, so step back by half the width to get the center.
    const x = Math.round(note_start.x + this.x_shift - this.getWidth() / 2);
    const y = note_start.y + this.y_shift;
    const y_position = absoluteYToOMR(this.note.getStave(), y);
    allGlyphs.push(`glyph {
type: ${ACCIDENTAL_TYPES[this.type]}
x: ${x}
y_position: ${y_position}
}`);
  }
  drawAccidental.apply(this, arguments);
};
/**
 * Samples a key from `probs` with probability proportional to its value.
 * @param {!Random} random The random generator
 * @param {!object} probs Map from key to probability. The probability
 *     values must sum to 1.
 * @return {!object} The sampled key from probs.
 */
function discreteSample(random, probs) {
  let cumulativeProb = 0;
  const randomUniform = random.real(0, 1);
  // Walk the CDF until the uniform sample falls inside a key's interval.
  for (let key of Object.keys(probs)) {
    if (randomUniform < cumulativeProb + probs[key]) {
      return key;
    }
    cumulativeProb += probs[key];
  }
  // BUG FIX: `ValueError` is not defined in JavaScript; the original code
  // threw a ReferenceError here instead of the intended diagnostic.
  throw new Error('Probabilities sum to ' + cumulativeProb);
}
// NOTE(review): PROB_LEDGER_NOTE is not referenced anywhere in this file —
// confirm whether it is used elsewhere or is dead code.
const PROB_LEDGER_NOTE = 0.1;
// Distribution over accidental modifiers appended to a generated note name.
// Values must sum to 1, as required by discreteSample().
const PROB_MODIFIERS = {
  '#': 0.25,
  'b': 0.25,
  '##': 0.02,
  'bb': 0.02,
  'n': 0.15,
  '': 0.31,
};
/** Base class for a clef that can sample random notes to render. */
class Clef {
  /**
   * Samples a random note to display for the clef.
   * @param {!Random} random The random generator
   * @return {string} The note name, with accidental.
   */
  genNote(random) {
    const modifier = discreteSample(random, PROB_MODIFIERS);
    // BUG FIX: String.prototype.replace() treats a string pattern as a
    // literal substring, so `.replace('(?=[0-9])', modifier)` never matched
    // and the accidental was silently dropped. Use a real regex lookahead to
    // insert the modifier just before the octave digit (e.g. 'C4' -> 'C#4').
    return random.pick(this.baseNotes_()).replace(/(?=[0-9])/, modifier);
  }

  // TODO(ringw): Why does the bass clef render notes in the same positions
  // as a treble clef? Fix this and add a different range of notes for bass.

  /**
   * @return {!array} The base note names (without accidentals) for
   *     notes that lie on the staff, or are within 2 ledger lines.
   * @private
   */
  baseNotes_() {
    return [
      'A3', 'B3', 'C4', 'D4', 'E4', 'F4', 'G4', 'A4', 'B4', 'C5', 'D5', 'E5',
      'F5', 'G5', 'A5', 'B5', 'C6'
    ];
  }
}
class TrebleClef extends Clef {
/** @return {string} the name used by VexFlow for the clef. */
name() {
return 'treble';
}
}
class BassClef extends Clef {
/** @return {string} the name used by VexFlow for the clef. */
name() {
return 'bass';
}
}
const CLEFS = [new TrebleClef(), new BassClef()];
if (!args.random_seeds) {
throw Error('--random_seeds is required');
}
const seedStrings = args.random_seeds.split(',');
const seeds = [];
seedStrings.forEach(function(seedString) {
const seed = parseInt(seedString, 10);
if (isNaN(seed)) {
throw Error('Seed is not an integer: ' + seedString);
}
seeds.push(seed);
});
// Give VexFlow a DOM to render into.
// NOTE(review): the JSDOM markup below looks truncated — the Factory renders
// into an element with id 'vexflow-div', so the markup presumably contained a
// div with that id originally; confirm against upstream before relying on it.
global.document = new jsdom.JSDOM('
').window.document;
const vf = new Vex.Flow.Factory(
    {renderer: {elementId: 'vexflow-div', width: 500, height: 200}});
const staveConstructor = vf.Stave;
// TODO(ringw): Support passing Vex.Flow.Stave options through addStave(), to
// avoid overriding "Stave" here.
/**
 * Wraps the Stave factory so every stave is created with solid-black
 * stafflines (fill_style) by default; caller-supplied options still win.
 * @param {Object} params Staff params map
 * @return {!vf.Stave} New Stave instance
 */
vf.Stave = function(params) {
  const paramsCopy = {};
  Object.assign(paramsCopy, params);
  const options = {fill_style: '#000000'};
  Object.assign(options, paramsCopy.options);
  paramsCopy.options = options;
  return staveConstructor.apply(this, [paramsCopy]);
};
// BUG FIX: `pages` was assigned without a declaration, creating an implicit
// global (and a ReferenceError in strict mode / ES modules). Declare it.
// Accumulates one {svg, page} record per random seed.
const pages = [];
seeds.forEach(function(seed) {
  // Seed the generator so each page is reproducible from its seed.
  const random = new Random(Random.engines.mt19937().seed(seed));
  const clef = random.pick(CLEFS);
  const score = vf.EasyScore();
  const system = vf.System();
  const notes = [];
  for (let i = 0; i < 4; i++) {
    notes.push(clef.genNote(random));
  }
  // TODO(ringw): Random durations.
  notes[0] = notes[0] + '/q';
  system
      .addStave({
        voices: [score.voice(score.notes(notes.join(', ')))],
      })
      .addClef(clef.name())
      .addTimeSignature('4/4');
  vf.draw();
  // The patched draw methods populated stafflineDistance, staffCenterLine,
  // and allGlyphs as side effects of vf.draw().
  let page_message = `staffline_distance: ${stafflineDistance}
${staffCenterLine}
`;
  allGlyphs.forEach(function(glyph) {
    page_message = page_message + glyph;
  });
  pages.push({
    'svg': document.getElementById('vexflow-div').innerHTML,
    'page': page_message
  });
  resetPageState();
});
process.stdout.write(JSON.stringify(pages));
================================================
FILE: moonlight/training/generation/vexflow_generator_pipeline.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A pipeline for generating labeled patch data from VexFlow.
See `generation.py` for details on the output data.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import apache_beam as beam
from moonlight.pipeline import pipeline_flags
from moonlight.training.generation import generation
from moonlight.training.generation import image_noise
import tensorflow as tf
FLAGS = flags.FLAGS

# Pipeline sizing and output location.
flags.DEFINE_integer('num_positive_examples', 1000000,
                     'The number of positive examples to generate.')
flags.DEFINE_string('examples_path', '', 'The path of the output examples.')
flags.DEFINE_integer('num_shards', None, 'Fixed number of shards (optional)')
# External commands invoked by the page-generation DoFn.
flags.DEFINE_multi_string('vexflow_generator_command', [
    '/usr/bin/env',
    'node',
    'vexflow_generator.js',
], 'Command line to run the node.js vexflow generator.')
flags.DEFINE_float('negative_to_positive_example_ratio', 10,
                   'Ratio of negative to positive examples.')
flags.DEFINE_integer('num_pages_per_batch', 100,
                     'The number of pages to emit in every node.js run.')
flags.DEFINE_multi_string(
    'svg_to_png_command', [
        '/usr/bin/env',
        'convert',
        'svg:-',
        'png:-',
    ], 'Command line to convert a SVG (stdin) to PNG (stdout).')
# Patch-extraction parameters.
flags.DEFINE_integer('patch_width', 15, 'Width of a staffline patch.')
flags.DEFINE_integer(
    'negative_example_distance', 3,
    'The minimum distance of a negative example from any glyph.')
def main(_):
  """Builds and runs the labeled-patch generation Beam pipeline."""
  with pipeline_flags.create_pipeline() as pipeline:
    # Ceiling division: enough pages (and batches of pages) to cover the
    # requested number of positive examples.
    per_image = generation.POSITIVE_EXAMPLES_PER_IMAGE
    total_pages = (FLAGS.num_positive_examples + per_image - 1) // per_image
    total_batches = (total_pages + FLAGS.num_pages_per_batch -
                     1) // FLAGS.num_pages_per_batch

    def add_noise(image):
      # TODO(ringw): Add better noise, maybe using generative adversarial
      # networks trained on real scores from IMSLP.
      return image_noise.gaussian_noise(image_noise.random_rotation(image))

    batches = pipeline | beam.transforms.Create(list(range(total_batches)))
    pages = batches | beam.ParDo(
        generation.PageGenerationDoFn(
            num_pages_per_batch=FLAGS.num_pages_per_batch,
            vexflow_generator_command=FLAGS.vexflow_generator_command,
            svg_to_png_command=FLAGS.svg_to_png_command))
    examples = pages | beam.ParDo(
        generation.PatchExampleDoFn(
            negative_example_distance=FLAGS.negative_example_distance,
            patch_width=FLAGS.patch_width,
            negative_to_positive_example_ratio=FLAGS
            .negative_to_positive_example_ratio,
            noise_fn=add_noise))
    _ = examples | beam.io.WriteToTFRecord(
        FLAGS.examples_path,
        beam.coders.ProtoCoder(tf.train.Example),
        num_shards=FLAGS.num_shards)


if __name__ == '__main__':
  app.run(main)
================================================
FILE: moonlight/util/BUILD
================================================
# Description:
# General utilities for OMR.

package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Convenience target aggregating every utility library in this package.
py_library(
    name = "util",
    deps = [
        ":functional_ops",
        ":memoize",
        ":more_iter_tools",
        ":patches",
        ":run_length",
        ":segments",
    ],
)

py_library(
    name = "functional_ops",
    srcs = ["functional_ops.py"],
    srcs_version = "PY2AND3",
    deps = [],  # tensorflow dep
)

py_test(
    name = "functional_ops_test",
    srcs = ["functional_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":functional_ops",
        # disable_tf2
        # tensorflow dep
    ],
)

# Pure-Python memoization helper; intentionally dependency-free.
py_library(
    name = "memoize",
    srcs = ["memoize.py"],
    srcs_version = "PY2AND3",
)

py_library(
    name = "more_iter_tools",
    srcs = ["more_iter_tools.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "more_iter_tools_test",
    srcs = ["more_iter_tools_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":more_iter_tools",
        # absl dep
        # absl/testing dep
        # numpy dep
        # six dep
    ],
)

py_library(
    name = "patches",
    srcs = ["patches.py"],
    srcs_version = "PY2AND3",
    deps = [],  # tensorflow dep
)

py_test(
    name = "patches_test",
    srcs = ["patches_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":patches",
        # disable_tf2
        # tensorflow dep
    ],
)

py_library(
    name = "run_length",
    srcs = ["run_length.py"],
    srcs_version = "PY2AND3",
    deps = [
        # tensorflow dep
        # tensorflow.contrib.image py dep
    ],
)

py_test(
    name = "run_length_test",
    srcs = ["run_length_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":run_length",
        # disable_tf2
        # tensorflow dep
    ],
)

py_library(
    name = "segments",
    srcs = ["segments.py"],
    srcs_version = "PY2AND3",
    deps = [
        # enum34 dep
        # tensorflow dep
    ],
)

py_test(
    name = "segments_test",
    size = "small",
    srcs = ["segments_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":segments",
        # disable_tf2
        # absl/testing dep
        # numpy dep
        # tensorflow dep
    ],
)
================================================
FILE: moonlight/util/functional_ops.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functional op helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def flat_map_fn(fn, elems, dtype=None):
  """Flat maps `fn` on items unpacked from `elems` on dimension 0.

  Analogous to `tf.map_fn`, but concatenates the result(s) along dimension 0.

  Args:
    fn: The callable to be performed.
    elems: The tensor of elements to apply `fn` to.
    dtype: The dtype of the output of `fn`.

  Returns:
    A tensor with the same rank as the input, and same dimensions except for
    dimension 0. The function results for each element, concatenated along
    dimension 0.
  """
  elems = tf.convert_to_tensor(elems)
  n = tf.shape(elems)[0]
  # Call fn once on an all-zeros element so the per-element output shape can
  # be determined before the loop runs.
  # NOTE(review): this assumes fn is safe to call on a zero-valued element
  # and that its output shape does not depend on the element's value —
  # confirm for new callers.
  zero_elem = tf.zeros(tf.shape(elems)[1:], elems.dtype)
  dummy_output = fn(zero_elem)
  output_elem_shape = tf.shape(dummy_output)[1:]
  # Start from an empty (length-0) result with the correct trailing shape;
  # each iteration concatenates one element's results onto it.
  initial_results = tf.zeros(
      tf.concat([[0], output_elem_shape], axis=0), dtype=dtype or elems.dtype)

  def compute(i, results):
    # Apply fn to element i and append its output along dimension 0.
    elem_results = fn(elems[i])
    return i + 1, tf.concat([results, elem_results], axis=0)

  # parallel_iterations=1 keeps the concatenation strictly in element order.
  # The results tensor grows along dimension 0, so its shape invariant must
  # be fully unknown (TensorShape(None)).
  return tf.while_loop(
      lambda i, _: i < n,
      compute, [0, initial_results],
      shape_invariants=[tf.TensorShape(()),
                        tf.TensorShape(None)],
      parallel_iterations=1)[1]
================================================
FILE: moonlight/util/functional_ops_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the functional ops helpers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import functional_ops
class FunctionalOpsTest(tf.test.TestCase):

  def testFlatMap(self):
    """flat_map_fn(tf.range, ns) concatenates each range(n) result."""
    with self.test_session():
      flat = functional_ops.flat_map_fn(tf.range, [1, 3, 0, 5])
      # range(1) + range(3) + range(0) + range(5), joined along dimension 0.
      self.assertAllEqual(flat.eval(), [0, 0, 1, 2, 0, 1, 2, 3, 4])


if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/util/memoize.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple memoizer for a function/method/property."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class MemoizedFunction(object):
  """Decorates a function to be memoized.

  Caches all invocations of the function with unique arguments. The arguments
  must be hashable.

  Decorated functions are not threadsafe. This decorator is currently used for
  TensorFlow graph construction, which happens in a single thread.
  """

  def __init__(self, function):
    # The wrapped callable, and a cache of its results keyed by the
    # positional-argument tuple.
    self._function = function
    self._results = {}

  def __call__(self, *args):
    """Calls the function to be memoized.

    Args:
      *args: The args to pass through to the function. Keyword arguments are
        not supported.

    Raises:
      TypeError if an argument is unhashable.

    Returns:
      The memoized return value of the wrapped function. The return value
      will be computed exactly once for each unique argument tuple.
    """
    # EAFP: look up the cached result first and compute only on a miss.
    # An unhashable argument raises TypeError from the lookup itself.
    try:
      return self._results[args]
    except KeyError:
      result = self._function(*args)
      self._results[args] = result
      return result
================================================
FILE: moonlight/util/more_iter_tools.py
================================================
"""More iterator utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
def iter_sample(iterator, count, rand=None):
  """Performs reservoir sampling on an iterator.

  The output is a list. The entire iterator must be read in one shot to
  determine any output element, and `count` elements need to be stored in
  memory.

  Args:
    iterator: An iterator/generator.
    count: The number of elements to sample.
    rand: Optional random object which is already seeded.

  Returns:
    A list with length `count`, or the contents of `iterator` if smaller.
  """
  rand = rand or random.Random()
  result = []
  for index, elem in enumerate(iterator):
    if index < count:
      # Fill the reservoir with the first `count` elements unconditionally.
      result.append(elem)
    else:
      # BUG FIX: the replacement step must only run once the reservoir is
      # full; without this else branch, each freshly-appended element was
      # immediately duplicated over a random earlier slot, corrupting the
      # sample. Element `index` replaces a random slot with probability
      # count / (index + 1) (Algorithm R), which makes every stream element
      # equally likely to end up in the final sample.
      random_index = rand.randint(0, index)
      if random_index < count:
        result[random_index] = elem
  return result
================================================
FILE: moonlight/util/more_iter_tools_test.py
================================================
"""Tests for more_iter_tools."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
from absl.testing import absltest
from moonlight.util import more_iter_tools
import numpy as np
from six import moves
class MoreIterToolsTest(absltest.TestCase):

  def testSample_count_0(self):
    # Requesting zero elements always yields an empty sample.
    self.assertEqual([], more_iter_tools.iter_sample(moves.range(100), 0))

  def testSample_iter_empty(self):
    # An empty iterator yields an empty sample regardless of count.
    self.assertEqual([], more_iter_tools.iter_sample(moves.range(0), 10))

  def testSample_distribution(self):
    # Sample ~10% of 100k elements with a fixed seed for determinism.
    sample = more_iter_tools.iter_sample(
        moves.range(0, 100000), 9999, rand=random.Random(12345))
    self.assertEqual(9999, len(sample))
    # Create a histogram with 10 bins.
    bins = np.bincount([elem // 10000 for elem in sample])
    self.assertEqual(10, len(bins))
    # Samples should be distributed roughly uniformly into bins.
    expected_bin_count = 9999 // 10
    for bin_count in bins:
      self.assertTrue(
          np.allclose(bin_count, expected_bin_count, rtol=0.1),
          '{} within 10% of {}'.format(bin_count, expected_bin_count))


if __name__ == '__main__':
  absltest.main()
================================================
FILE: moonlight/util/patches.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extracts patches from image(s)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
def patches_1d(images, patch_width):
  """Extract patches along the last dimension of `images`.

  Thin wrapper around `tf.extract_image_patches` that only takes horizontal
  slices of the input, and reshapes the output to N-D images.

  Args:
    images: The image(s) to extract patches from. Shape at least 2D, with
      shape images_shape + (height, width). The image height and width must
      be statically known (set on `images.get_shape()`).
    patch_width: Width of a patch. int or long (must be statically known).

  Returns:
    Patches extracted from each image in `images`. Shape
    images_shape + (width - patch_width + 1, height, width).

  Raises:
    ValueError: If patch_width is not an int or long.
  """
  # patch_width must be a Python int, not a Tensor, for the ksizes argument.
  if not isinstance(patch_width, six.integer_types):
    raise ValueError("patch_width must be an integer")
  # The shape of the input, excluding image height and width.
  images_shape = tf.shape(images)[:-2]
  # num_images is not necessary to know at graph creation time.
  num_images = tf.reduce_prod(images_shape)
  # image_height must be statically known for tf.extract_image_patches.
  image_height = int(images.get_shape()[-2])
  image_width = tf.shape(images)[-1]

  def do_extract_patches():
    """Returns the image patches, assuming images.shape[0] > 0."""
    # Reshape to (num_images, height, width, channels), the NHWC layout that
    # tf.extract_image_patches expects.
    images_nhwc = tf.reshape(
        images, tf.stack([num_images, image_height, image_width, 1]))
    # The kernel spans the full image height, so every patch is a
    # full-height horizontal slice of width patch_width.
    patches = tf.extract_image_patches(
        images_nhwc, [1, image_height, patch_width, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="VALID")
    # Restore the leading images_shape dims and unflatten each patch into
    # (height, patch_width).
    patches_shape = tf.concat(
        [images_shape, [tf.shape(patches)[2], image_height, patch_width]],
        axis=0)
    return tf.reshape(patches, patches_shape)

  def empty_patches():
    # Zero-patch result with the correct shape for an empty batch.
    return tf.zeros(
        tf.concat([images_shape, [0, image_height, patch_width]], axis=0),
        images.dtype)

  # NOTE(review): the cond appears to exist to keep extract_image_patches
  # away from an empty batch — confirm that is the intended reason.
  return tf.cond(tf.greater(num_images, 0), do_extract_patches, empty_patches)
================================================
FILE: moonlight/util/patches_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the patches utility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import patches
from six import moves
class PatchesTest(tf.test.TestCase):

  def test2D(self):
    # A single 100x200 image; patches slide horizontally one column at a time.
    image_t = tf.random_uniform((100, 200))
    image_t.set_shape((100, 200))
    patch_width = 10
    patches_t = patches.patches_1d(image_t, patch_width)
    with self.test_session() as sess:
      image_arr, patches_arr = sess.run((image_t, patches_t))
    self.assertEqual(patches_arr.shape,
                     (200 - patch_width + 1, 100, patch_width))
    # Patch i must equal the full-height slice starting at column i.
    for i in moves.xrange(patches_arr.shape[0]):
      self.assertAllEqual(patches_arr[i], image_arr[:, i:i + patch_width])

  def test4D(self):
    # Batched images with two leading batch dimensions (4 x 8 images).
    height = 15
    width = 20
    image_t = tf.random_uniform((4, 8, height, width))
    image_t.set_shape((None, None, height, width))
    patch_width = 10
    patches_t = patches.patches_1d(image_t, patch_width)
    with self.test_session() as sess:
      image_arr, patches_arr = sess.run((image_t, patches_t))
    self.assertEqual(patches_arr.shape,
                     (4, 8, width - patch_width + 1, height, patch_width))
    # Each patch is the corresponding horizontal slice of its own image.
    for i in moves.xrange(patches_arr.shape[0]):
      for j in moves.xrange(patches_arr.shape[1]):
        for k in moves.xrange(patches_arr.shape[2]):
          self.assertAllEqual(patches_arr[i, j, k],
                              image_arr[i, j, :, k:k + patch_width])

  def testEmpty(self):
    # A zero-sized batch dimension must yield zero patches, not an error.
    height = 15
    width = 20
    image_t = tf.zeros((1, 2, 0, 3, 4, height, width))
    patch_width = 10
    patches_t = patches.patches_1d(image_t, patch_width)
    with self.test_session():
      self.assertEqual(patches_t.eval().shape,
                       (1, 2, 0, 3, 4, 0, height, patch_width))


if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/util/run_length.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image run length encoding.
Each run is a subsequence of consecutive pixels with the same value. The
run-length encoding is the list of all runs in order, with their lengths and
values.
See: https://en.wikipedia.org/wiki/Run-length_encoding
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import image as contrib_image
def vertical_run_length_encoding(image):
  """Returns the runs in each column of the image.

  A run is a subsequence of consecutive pixels that all have the same value.
  Internally, we treat the image as batches of single-column images in order
  to use connected component analysis.

  Args:
    image: A 2D image.

  Returns:
    The column index of each vertical run.
    The value in the image for each vertical run.
    The length of each vertical run.
  """
  with tf.name_scope('run_length_encoding'):
    image = tf.convert_to_tensor(image, name='image', dtype=tf.bool)
    # Set arbitrary, distinct, nonzero values for True and False pixels.
    # True pixels map to 2, and False pixels map to 1. Nonzero values are
    # required because connected_components treats 0 as background.
    # Transpose the image, and insert an extra dimension. This creates a
    # batch of "images" for connected component analysis, where each image
    # is a single column of the original image. Therefore, the connected
    # components are actually runs from a single column.
    components = contrib_image.connected_components(
        tf.to_int32(tf.expand_dims(tf.transpose(image), axis=1)) + 1)
    # Flatten in order to use with unsorted segment ops.
    flat_components = tf.reshape(components, [-1])
    # Component ids start at 1 (0 is background), so there are up to
    # `reduce_max + 1` segment ids; the max with 0 keeps this sane for an
    # empty image.
    num_components = tf.maximum(0, tf.reduce_max(components) + 1)
    # Get the column index corresponding to each pixel present in
    # flat_components.
    column_indices = tf.reshape(
        tf.tile(
            # Count 0 through `width - 1` on axis 0, then repeat each
            # element `height` times.
            tf.expand_dims(tf.range(tf.shape(image)[1]), axis=1),
            multiples=[1, tf.shape(image)[0]]),
        # pyformat: disable
        [-1])
    # Take the column index for each component. For each component index k,
    # we want any entry of column_indices where the corresponding entry in
    # flat_components is k. column_indices should be the same for all pixels
    # in the same component, so we can just take the max of all of them.
    # Disregard component 0, which just represents all of the zero pixels
    # across the entire array (should be empty, because we pass in a nonzero
    # image).
    component_columns = tf.unsorted_segment_max(column_indices,
                                                flat_components,
                                                num_components)[1:]
    # Take the original value of each component. Again, the value should be
    # the same for all pixels in a single component, so we can just take the
    # max.
    component_values = tf.unsorted_segment_max(
        tf.to_int32(tf.reshape(tf.transpose(image), [-1])), flat_components,
        num_components)[1:]
    # Take the length of each component (run), by counting the number of
    # pixels in the component.
    component_lengths = tf.to_int32(tf.bincount(flat_components)[1:])
    return component_columns, component_values, component_lengths
================================================
FILE: moonlight/util/run_length_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for run length encoding."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import run_length
class RunLengthTest(tf.test.TestCase):
  """Exercises vertical run-length encoding of boolean images."""

  def testEmpty(self):
    # A 0x0 image must yield empty columns, values, and lengths.
    encoded = run_length.vertical_run_length_encoding(tf.zeros((0, 0), tf.bool))
    with self.test_session() as session:
      columns, values, lengths = session.run(encoded)
      self.assertAllEqual(columns, [])
      self.assertAllEqual(values, [])
      self.assertAllEqual(lengths, [])

  def testBooleanImage(self):
    # Each column is scanned top to bottom; runs alternate between False (0)
    # and True (1) values.
    image = tf.cast(
        [
            [0, 0, 1, 0, 0, 1],
            # pyformat: disable
            [1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1],
            [0, 1, 1, 0, 1, 0]
        ],
        tf.bool)
    with self.test_session() as session:
      columns, values, lengths = session.run(
          run_length.vertical_run_length_encoding(image))
      # Expected runs per column: 3, 4, 1, 3, 2, 2.
      self.assertAllEqual(columns,
                          [0] * 3 + [1] * 4 + [2] + [3] * 3 + [4] * 2 + [5] * 2)
      self.assertAllEqual(values, [0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
      self.assertAllEqual(lengths, [1, 1, 2, 1, 1, 1, 1, 4, 1, 2, 1, 1, 3, 3, 1])
# Allow running the tests directly with `python run_length_test.py`.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/util/segments.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Segment/run length utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
import tensorflow as tf
class SegmentsMode(enum.Enum):
  """How segment positions are reported by the segmentation functions.

  STARTS yields the first index of each segment; CENTERS yields the floored
  midpoint of each segment.
  """
  STARTS = 1
  CENTERS = 2
def true_segments_1d(segments,
                     mode=SegmentsMode.CENTERS,
                     max_gap=0,
                     min_length=0,
                     name=None):
  """Labels contiguous True runs in segments.

  Args:
    segments: 1D boolean tensor.
    mode: The SegmentsMode. Returns the start of each segment (STARTS), or the
        rounded center of each segment (CENTERS).
    max_gap: Fill gaps of length at most `max_gap` between true segments. int.
    min_length: Minimum length of a returned segment. int.
    name: Optional name for the op.

  Returns:
    run_centers: int32 tensor. Depending on `mode`, either the start of each
        True run, or the (rounded) center of each True run.
    run_lengths: int32; the lengths of each True run.
  """
  with tf.name_scope(name, "true_segments", [segments]):
    segments = tf.convert_to_tensor(segments, tf.bool)
    # Label every maximal run (True or False alike) by its start and length.
    run_starts, run_lengths = _segments_1d(segments, mode=SegmentsMode.STARTS)
    # Take only the True runs. After whichever run is True first, the True runs
    # are every other run.
    first_run = tf.cond(
        # First value is False, or all values are False. Handles empty segments
        # correctly.
        tf.logical_or(tf.reduce_any(segments[0:1]), ~tf.reduce_any(segments)),
        lambda: tf.constant(0),
        lambda: tf.constant(1))
    num_runs = tf.shape(run_starts)[0]
    run_nums = tf.range(num_runs)
    # True runs alternate with False runs, so parity relative to first_run
    # identifies them.
    is_true_run = tf.equal(run_nums % 2, first_run % 2)
    # Find gaps between True runs that can be merged. A gap is a False run
    # strictly between two True runs (never the leading or trailing False run).
    is_gap = tf.logical_and(
        tf.not_equal(run_nums % 2, first_run % 2),
        tf.logical_and(
            tf.greater(run_nums, first_run), tf.less(run_nums, num_runs - 1)))
    fill_gap = tf.logical_and(is_gap, tf.less_equal(run_lengths, max_gap))
    # Segment the consecutive runs of True or False values based on whether they
    # are True, or are a gap of False values that can be bridged. Then, flatten
    # the runs of runs.
    runs_to_merge = tf.logical_or(is_true_run, fill_gap)
    run_of_run_starts, _ = _segments_1d(runs_to_merge, mode=SegmentsMode.STARTS)
    # Get the start of every new run from the original run starts.
    merged_run_starts = tf.gather(run_starts, run_of_run_starts)
    # Make an array mapping the original runs to their run of runs. Increment
    # the number for every run of run start except for the first one, so that
    # the array has values from 0 to num_run_of_runs.
    merged_run_inds = tf.cumsum(
        tf.sparse_to_dense(
            sparse_indices=tf.cast(run_of_run_starts[1:, None], tf.int64),
            output_shape=tf.cast(num_runs[None], tf.int64),
            sparse_values=tf.ones_like(run_of_run_starts[1:])))
    # Sum the lengths of the original runs that were merged.
    merged_run_lengths = tf.segment_sum(run_lengths, merged_run_inds)
    if mode is SegmentsMode.CENTERS:
      merged_starts_or_centers = (
          merged_run_starts + tf.floordiv(merged_run_lengths - 1, 2))
    else:
      merged_starts_or_centers = merged_run_starts
    # If there are no true values, increment first_run to 1, so we will skip
    # the single (false) run.
    first_run += tf.to_int32(tf.logical_not(tf.reduce_any(segments)))
    # Keep every other merged run starting at first_run: exactly the True runs.
    merged_starts_or_centers = merged_starts_or_centers[first_run::2]
    merged_run_lengths = merged_run_lengths[first_run::2]
    # Only take segments at least min_length long.
    is_long_enough = tf.greater_equal(merged_run_lengths, min_length)
    is_long_enough.set_shape([None])
    merged_starts_or_centers = tf.boolean_mask(merged_starts_or_centers,
                                               is_long_enough)
    merged_run_lengths = tf.boolean_mask(merged_run_lengths, is_long_enough)
    return merged_starts_or_centers, merged_run_lengths
def _segments_1d(values, mode, name=None):
  """Labels consecutive runs of the same value.

  Args:
    values: 1D tensor of any type.
    mode: The SegmentsMode. Returns the start of each segment (STARTS), or the
        rounded center of each segment (CENTERS).
    name: Optional name for the op.

  Returns:
    run_centers: int32 tensor; for each run of equal consecutive values, its
        start index (STARTS mode) or its floored center index (CENTERS mode).
    run_lengths: int32 tensor; the lengths of each run.

  Raises:
    ValueError: if mode is not recognized.
  """
  with tf.name_scope(name, "segments", [values]):

    def do_segments(values):
      """Actually does segmentation.

      Args:
        values: 1D tensor of any type. Non-empty.

      Returns:
        run_centers: int32 tensor
        run_lengths: int32 tensor

      Raises:
        ValueError: if mode is not recognized.
      """
      length = tf.shape(values)[0]
      values = tf.convert_to_tensor(values)
      # The first run has id 0, so we don't increment the id.
      # Otherwise, the id is incremented when the value changes.
      run_start_bool = tf.concat(
          [[False], tf.not_equal(values[1:], values[:-1])], axis=0)
      # Cumulative sum the run starts to get the run ids.
      segment_ids = tf.cumsum(tf.cast(run_start_bool, tf.int32))
      if mode is SegmentsMode.STARTS:
        # The start of a run is its minimum index.
        run_centers = tf.segment_min(tf.range(length), segment_ids)
      elif mode is SegmentsMode.CENTERS:
        # The floored mean index of a run is its center.
        run_centers = tf.segment_mean(
            tf.cast(tf.range(length), tf.float32), segment_ids)
        run_centers = tf.cast(tf.floor(run_centers), tf.int32)
      else:
        raise ValueError("Unexpected mode: %s" % mode)
      run_lengths = tf.segment_sum(tf.ones([length], tf.int32), segment_ids)
      return run_centers, run_lengths

    def empty_segments():
      # Empty input: no runs at all.
      return (tf.zeros([0], tf.int32), tf.zeros([0], tf.int32))

    # Guard against empty input, which the segment ops do not handle.
    return tf.cond(
        tf.greater(tf.shape(values)[0], 0), lambda: do_segments(values),
        empty_segments)
def peaks(values, minval=None, invalidate_distance=0, name=None):
  """Labels peaks in values.

  Args:
    values: 1D tensor of a numeric type.
    minval: Minimum value which is considered a peak.
    invalidate_distance: Invalidates nearby potential peaks. The peaks are
        searched sequentially by descending value, and from left to right for
        equal values. Once a peak is found in this order, it invalidates any
        peaks yet to be seen that are <= invalidate_distance away. A distance
        of 0 effectively produces no invalidation.
    name: Optional name for the op.

  Returns:
    peak_centers: The (rounded) centers of each peak, which are locations where
        the value is higher than the value before and after. If there is a run
        of equal values at the peak, the rounded center of the run is returned.
        int32 1D tensor.
  """
  with tf.name_scope(name, "peaks", [values]):
    values = tf.convert_to_tensor(values, name="values")
    invalidate_distance = tf.convert_to_tensor(
        invalidate_distance, name="invalidate_distance", dtype=tf.int32)
    # Segment the values and find local maxima.
    # Take the center of each run of consecutive equal values.
    segment_centers, _ = _segments_1d(values, mode=SegmentsMode.CENTERS)
    segment_values = tf.gather(values, segment_centers)
    # If we have zero or one segments, there are no peaks. Just use zeros as the
    # edge values in that case.
    first_val, second_val, penultimate_val, last_val = tf.cond(
        # pyformat: disable
        tf.greater_equal(tf.shape(segment_values)[0], 2),
        lambda: tuple(segment_values[i] for i in (0, 1, -2, -1)),
        lambda: tuple(tf.constant(0, values.dtype) for i in range(4)))
    # Each segment must be greater than the segment before and after it.
    # Endpoints only need to beat their single neighbor.
    segment_is_peak = tf.concat(
        [[first_val > second_val],
         tf.greater(segment_values[1:-1],
                    tf.maximum(segment_values[:-2], segment_values[2:])),
         [last_val > penultimate_val]],
        axis=0)
    if minval is not None:
      # Filter the peaks by minval.
      segment_is_peak = tf.logical_and(segment_is_peak,
                                       tf.greater_equal(segment_values, minval))
    # Get the center coordinates of each peak, and sort by descending value.
    all_peaks = tf.boolean_mask(segment_centers, segment_is_peak)
    num_peaks = tf.shape(all_peaks)[0]
    peak_values = tf.boolean_mask(segment_values, segment_is_peak)
    _, peak_order = tf.nn.top_k(peak_values, k=num_peaks, sorted=True)
    all_peaks = tf.gather(all_peaks, peak_order)
    all_peaks.set_shape([None])

    # Loop over the peaks, accepting one at a time and possibly invalidating
    # other ones.
    def loop_condition(_, current_peaks):
      return tf.shape(current_peaks)[0] > 0

    def loop_body(accepted_peaks, current_peaks):
      # Accept the highest remaining peak, then drop any later peaks within
      # invalidate_distance of it.
      peak = current_peaks[0]
      remaining_peaks = current_peaks[1:]
      keep_peaks = tf.greater(
          tf.abs(remaining_peaks - peak), invalidate_distance)
      remaining_peaks = tf.boolean_mask(remaining_peaks, keep_peaks)
      return tf.concat([accepted_peaks, [peak]], axis=0), remaining_peaks

    accepted_peaks = tf.while_loop(
        loop_condition,
        loop_body, [tf.zeros([0], all_peaks.dtype), all_peaks],
        shape_invariants=[tf.TensorShape([None]),
                          tf.TensorShape([None])])[0]
    # Sort the peaks by index.
    # TODO(ringw): Add a tf.sort op that sorts in ascending order.
    sorted_negative_peaks, _ = tf.nn.top_k(
        -accepted_peaks, k=tf.shape(accepted_peaks)[0], sorted=True)
    return -sorted_negative_peaks
================================================
FILE: moonlight/util/segments_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the segmentation utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
import numpy as np
import tensorflow as tf
from moonlight.util import segments
# TODO(ringw): Remove the multiple inheritance once tf.test.TestCase extends
# absltest.
class SegmentsTest(tf.test.TestCase, absltest.TestCase):
  """Tests for segments.true_segments_1d and segments.peaks."""

  def test_true_segments_1d(self):
    # Arbitrary boolean array to get True and False runs from.
    values = tf.constant([True, True, True, False, True, False, True, True])
    centers, lengths = segments.true_segments_1d(
        values, mode=segments.SegmentsMode.CENTERS)
    with self.test_session():
      self.assertAllEqual(centers.eval(), [1, 4, 6])
      self.assertAllEqual(lengths.eval(), [3, 1, 2])
    starts, lengths = segments.true_segments_1d(
        values, mode=segments.SegmentsMode.STARTS)
    with self.test_session():
      self.assertAllEqual(starts.eval(), [0, 4, 6])
      self.assertAllEqual(lengths.eval(), [3, 1, 2])

  def test_true_segments_1d_large(self):
    # Arbitrary boolean array to get True and False runs from.
    run_values = [False, True, False, True, False, True, False, True]
    run_lengths = [3, 5, 2, 6, 4, 8, 7, 1]
    values = tf.constant(np.repeat(run_values, run_lengths))
    centers, lengths = segments.true_segments_1d(
        values, mode=segments.SegmentsMode.CENTERS)
    with self.test_session():
      # Each expected center is the run start plus half the (length - 1).
      self.assertAllEqual(centers.eval(), [
          sum(run_lengths[:1]) + (run_lengths[1] - 1) // 2,
          sum(run_lengths[:3]) + (run_lengths[3] - 1) // 2,
          sum(run_lengths[:5]) + (run_lengths[5] - 1) // 2,
          sum(run_lengths[:7]) + (run_lengths[7] - 1) // 2
      ])
      self.assertAllEqual(lengths.eval(), run_lengths[1::2])
    starts, lengths = segments.true_segments_1d(
        values, mode=segments.SegmentsMode.STARTS)
    with self.test_session():
      self.assertAllEqual(starts.eval(), [
          sum(run_lengths[:1]),
          sum(run_lengths[:3]),
          sum(run_lengths[:5]),
          sum(run_lengths[:7])
      ])
      self.assertAllEqual(lengths.eval(), run_lengths[1::2])

  def test_true_segments_1d_empty(self):
    # All modes and small max_gap values should handle an empty input.
    for mode in list(segments.SegmentsMode):
      for max_gap in [0, 1]:
        centers, lengths = segments.true_segments_1d([],
                                                     mode=mode,
                                                     max_gap=max_gap)
        with self.test_session():
          self.assertAllEqual(centers.eval(), [])
          self.assertAllEqual(lengths.eval(), [])

  def test_true_segments_1d_max_gap(self):
    # Arbitrary boolean array to get True and False runs from.
    values = tf.constant([
        False, False,
        True, True, True,
        False, False,
        True,
        False, False, False, False, False, False,
        True, True, True, True,
        False,
        True, True,
        False, False,
        True,
    ])  # pyformat: disable
    centers, lengths = segments.true_segments_1d(values, max_gap=0)
    with self.test_session():
      self.assertAllEqual(centers.eval(), [3, 7, 15, 19, 23])
      self.assertAllEqual(lengths.eval(), [3, 1, 4, 2, 1])
    centers, lengths = segments.true_segments_1d(values, max_gap=1)
    with self.test_session():
      self.assertAllEqual(centers.eval(), [3, 7, 17, 23])
      self.assertAllEqual(lengths.eval(), [3, 1, 7, 1])
    # Gaps of 2..5 merge more runs; the 6-wide gap needs max_gap=6.
    for max_gap in range(2, 6):
      centers, lengths = segments.true_segments_1d(values, max_gap=max_gap)
      with self.test_session():
        self.assertAllEqual(centers.eval(), [4, 18])
        self.assertAllEqual(lengths.eval(), [6, 10])
    centers, lengths = segments.true_segments_1d(values, max_gap=6)
    with self.test_session():
      self.assertAllEqual(centers.eval(), [12])
      self.assertAllEqual(lengths.eval(), [22])

  # TODO(ringw): Make these tests parameterized when absl is released.
  def test_true_segments_1d_all_false_length_1(self):
    self._test_true_segments_1d_all_false(1)

  def test_true_segments_1d_all_false_length_2(self):
    self._test_true_segments_1d_all_false(2)

  def test_true_segments_1d_all_false_length_8(self):
    self._test_true_segments_1d_all_false(8)

  def test_true_segments_1d_all_false_length_11(self):
    self._test_true_segments_1d_all_false(11)

  def _test_true_segments_1d_all_false(self, length):
    # An all-False input of any length has no True segments.
    centers, lengths = segments.true_segments_1d(tf.zeros(length, tf.bool))
    with self.test_session():
      self.assertAllEqual(centers.eval(), [])
      self.assertAllEqual(lengths.eval(), [])

  def test_true_segments_1d_min_length_0(self):
    self._test_true_segments_1d_min_length(0)

  def test_true_segments_1d_min_length_1(self):
    self._test_true_segments_1d_min_length(1)

  def test_true_segments_1d_min_length_2(self):
    self._test_true_segments_1d_min_length(2)

  def test_true_segments_1d_min_length_3(self):
    self._test_true_segments_1d_min_length(3)

  def test_true_segments_1d_min_length_4(self):
    self._test_true_segments_1d_min_length(4)

  def test_true_segments_1d_min_length_5(self):
    self._test_true_segments_1d_min_length(5)

  def test_true_segments_1d_min_length_6(self):
    self._test_true_segments_1d_min_length(6)

  def _test_true_segments_1d_min_length(self, min_length):
    # Arbitrary boolean array to get True and False runs from.
    values = tf.constant([
        False, False, False,
        True,
        False,
        True, True,
        False,
        True,
        False,
        True, True, True, True,
        False,
        True, True,
    ])  # pyformat: disable
    # Filter the known runs by min_length to get the expected output.
    all_centers = np.asarray([3, 5, 8, 11, 15])
    all_lengths = np.asarray([1, 2, 1, 4, 2])
    expected_centers = all_centers[all_lengths >= min_length]
    expected_lengths = all_lengths[all_lengths >= min_length]
    centers, lengths = segments.true_segments_1d(values, min_length=min_length)
    with self.test_session():
      self.assertAllEqual(expected_centers, centers.eval())
      self.assertAllEqual(expected_lengths, lengths.eval())

  def test_peaks(self):
    values = tf.constant([5, 3, 1, 1, 0, 1, 2, 3, 3, 3, 2, 3, 4, 1, 2])
    with self.test_session():
      self.assertAllEqual(segments.peaks(values).eval(), [0, 8, 12, 14])
      self.assertAllEqual(segments.peaks(values, minval=3).eval(), [0, 8, 12])

  def test_peaks_empty(self):
    with self.test_session():
      self.assertAllEqual(segments.peaks([]).eval(), [])

  def test_peaks_invalidate_distance(self):
    values = tf.constant([0, 0, 10, 0, 5, 3, 2, 1, 2, 3, 8, 8, 7, 8])
    with self.test_session():
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=0).eval(), [2, 4, 10, 13])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=1).eval(), [2, 4, 10, 13])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=2).eval(), [2, 10, 13])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=3).eval(), [2, 10])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=4).eval(), [2, 10])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=7).eval(), [2, 10])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=8).eval(), [2, 13])
      self.assertAllEqual(
          segments.peaks(values, invalidate_distance=99).eval(), [2])

  def test_peaks_array_filled_with_same_value(self):
    # A constant array has no local maxima.
    for value in (0, 42, 4.2):
      arr = tf.fill([100], value)
      with self.test_session():
        self.assertEmpty(segments.peaks(arr).eval())

  def test_peaks_one_segment(self):
    values = tf.constant([0, 0, 0, 0, 3, 0, 0, 0, 0])
    with self.test_session():
      self.assertAllEqual(segments.peaks(values).eval(), [4])
# Allow running the tests directly with `python segments_test.py`.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/vision/BUILD
================================================
# Description:
# General computer vision routines for OMR.

package(
    default_visibility = ["//moonlight:__subpackages__"],
)

licenses(["notice"])  # Apache 2.0

# Umbrella library that re-exports all of the vision modules.
py_library(
    name = "vision",
    deps = [
        ":hough",
        ":images",
        ":morphology",
    ],
)

# Image utilities (translation via tf.contrib.image).
py_library(
    name = "images",
    srcs = ["images.py"],
    srcs_version = "PY2AND3",
    deps = [
        # tensorflow dep
        # tensorflow.contrib.image py dep
    ],
)

py_test(
    name = "images_test",
    srcs = ["images_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":images",
        # disable_tf2
        # tensorflow dep
    ],
)

# Hough transform for line detection.
py_library(
    name = "hough",
    srcs = ["hough.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//moonlight/util:segments",
        # tensorflow dep
    ],
)

py_test(
    name = "hough_test",
    srcs = ["hough_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":hough",
        # disable_tf2
        # numpy dep
        # tensorflow dep
    ],
)

# Binary morphology (erosion/dilation) ops; built on :images translation.
py_library(
    name = "morphology",
    srcs = ["morphology.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":images",
        # tensorflow dep
    ],
)

py_test(
    name = "morphology_test",
    srcs = ["morphology_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":morphology",
        # disable_tf2
        # numpy dep
        # tensorflow dep
    ],
)
================================================
FILE: moonlight/vision/hough.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hough transform for line detection.
Transforms a boolean image to the Hough space, where each entry corresponds to a
line parameterized by the angle `theta` clockwise from vertical (in radians),
and the distance `rho` (in pixels; the distance from coordinate `(0, 0)` in the
image to the closest point in the line).
For performance, the image should be sparse, containing mostly False elements,
because `tf.where(image)` will be called.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.util import segments
def hough_lines(image, thetas):
  """Hough transform of a boolean image.

  Args:
    image: The image. 2D boolean tensor. Should be sparse (mostly Falses).
    thetas: 1D float tensor (or array-like, e.g. a list or ndarray) of possible
        angles from the vertical for the line.

  Returns:
    The Hough space for the image. Shape `(num_theta, num_rho)`, where
    `num_rho` is `ceil(sqrt(height**2 + width**2))` — the largest possible
    distance from the origin to a pixel in the image.
  """
  # Accept plain Python lists as well as tensors/ndarrays. The previous code
  # read `thetas.dtype` directly, which fails for a list.
  thetas = tf.convert_to_tensor(thetas, name="thetas")
  # Coordinates of the nonzero pixels, as (row, col) = (y, x) pairs.
  coords = tf.cast(tf.where(image), thetas.dtype)
  rho = tf.cast(
      # x cos theta + y sin theta, for every (theta, point) pair.
      tf.expand_dims(coords[:, 1], 0) * tf.cos(thetas)[:, None] +
      tf.expand_dims(coords[:, 0], 0) * tf.sin(thetas)[:, None],
      tf.int32)
  height = tf.cast(tf.shape(image)[0], tf.float64)
  width = tf.cast(tf.shape(image)[1], tf.float64)
  # The diagonal bounds the maximum possible rho value.
  num_rho = tf.cast(tf.ceil(tf.sqrt(height * height + width * width)), tf.int32)
  # Count the votes for each (theta, rho) bin.
  hough_bins = _bincount_2d(rho, num_rho)
  return hough_bins
def hough_peaks(hough_bins, thetas, minval=0, invalidate_distance=0):
  """Finds the peak lines in Hough space.

  Args:
    hough_bins: Hough bins returned by `hough_lines`.
    thetas: Angles; argument given to `hough_lines`.
    minval: Minimum vote count for a Hough bin to be considered. int or float.
    invalidate_distance: When selecting a line `(rho, theta)`, invalidate all
        lines with the same theta and `+- invalidate_distance` from `rho`.
        int32. Caveat: this should only be used if all theta values are
        similar. If thetas cover a wide range, this will invalidate lines that
        might not even intersect.

  Returns:
    Tensor of peak rho indices (int32).
    Tensor of peak theta values (same dtype as `thetas`).
  """
  thetas = tf.convert_to_tensor(thetas)
  bin_score_dtype = thetas.dtype  # floating point score derived from hough_bins
  minval = tf.convert_to_tensor(minval)
  if minval.dtype.is_floating:
    # Vote counts are integers, so a fractional minval is rounded up.
    minval = tf.ceil(minval)
  invalidate_distance = tf.convert_to_tensor(
      invalidate_distance, dtype=tf.int32)
  # Choose the theta with the highest bin value for each rho.
  selected_theta_ind = tf.argmax(hough_bins, axis=0)
  # Take the Hough bin value for each rho and the selected theta.
  hough_bins = tf.gather_nd(
      hough_bins,
      tf.stack([
          tf.cast(selected_theta_ind, tf.int32),
          tf.range(tf.shape(hough_bins)[1])
      ],
               axis=1))
  # hough_bins are integers. Subtract a penalty (< 1) for lines that are not
  # horizontal or vertical, so that we break ties in favor of the more
  # horizontal or vertical line.
  infinitesimal = tf.constant(1e-10, bin_score_dtype)
  # Decrease minval so we don't discard bins that are penalized, if they
  # originally equalled minval.
  minval = tf.cast(minval, bin_score_dtype) - infinitesimal
  selected_thetas = tf.gather(thetas, selected_theta_ind)
  # min(|sin(t)|, |cos(t)|) is 0 for horizontal and vertical angles, and between
  # 0 and 1 otherwise.
  penalty = tf.multiply(
      tf.minimum(
          tf.abs(tf.sin(selected_thetas)), tf.abs(tf.cos(selected_thetas))),
      infinitesimal)
  bin_score = tf.cast(hough_bins, bin_score_dtype) - penalty
  # Find the peaks in the 1D hough_bins array.
  peak_rhos = segments.peaks(
      bin_score, minval=minval, invalidate_distance=invalidate_distance)
  # Get the actual angles for each selected peak.
  peak_thetas = tf.gather(thetas, tf.gather(selected_theta_ind, peak_rhos))
  return peak_rhos, peak_thetas
def _bincount_2d(values, num_values):
  """Bincounts each row of values.

  Args:
    values: The values to bincount. 2D integer tensor.
    num_values: The number of columns of the output. Entries in `values` that
        are `>= num_values` will be ignored.

  Returns:
    The bin counts. Shape `(values.shape[0], num_values)`. The `i`th row
    contains the result of `tf.bincount(values[i, :], maxlength=num_values)`.
  """
  num_rows = tf.shape(values)[0]
  # Shift row i's values into the disjoint id range
  # [i * num_values, (i + 1) * num_values), so a single flat bincount keeps
  # the rows separate.
  row_offsets = tf.range(num_rows)[:, None] * num_values
  shifted_values = values + row_offsets
  # Out-of-range entries would land in another row's range; drop them.
  in_range = tf.logical_and(
      tf.greater_equal(values, 0), tf.less(values, num_values))
  flat_ids = tf.boolean_mask(shifted_values, in_range)
  total_bins = num_rows * num_values
  counts = tf.bincount(flat_ids, minlength=total_bins, maxlength=total_bins)
  # Fold the flat counts back into one row per input row.
  return tf.reshape(counts, [num_rows, num_values])
================================================
FILE: moonlight/vision/hough_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Hough transform."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.vision import hough
class HoughTest(tf.test.TestCase):
  """Tests for hough.hough_lines and hough.hough_peaks."""

  def testHorizontalLines(self):
    image = np.asarray(
        [[0, 0, 0, 0, 0, 0, 0],
         [0, 1, 1, 1, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]])  # pyformat: disable
    thetas = np.asarray([np.pi / 2, np.pi / 4, 0, -np.pi / 4])
    with self.test_session() as sess:
      hough_bins = sess.run(hough.hough_lines(image, thetas))
      self.assertAllEqual(
          hough_bins,
          # theta pi/2 gives the horizontal projection (sum each row).
          [[0, 5, 0, 0, 7, 0, 4, 0, 0, 0, 0],
           # theta pi/4 rotates the lines counter-clockwise from horizontal, and
           # higher rho values go down and right into the image.
           [0, 1, 3, 2, 5, 2, 2, 1, 0, 0, 0],
           # theta 0 gives the vertical projection (sum each column).
           [2, 3, 3, 3, 2, 2, 1, 0, 0, 0, 0],
           # theta -pi/4 rotates the lines counter-clockwise from vertical, and
           # higher rho values go up and right away from the image.
           [5, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0]])  # pyformat: disable

  def testHoughPeaks_verticalLines(self):
    # One perfectly vertical line (column 1) and one slightly rotated line.
    image = np.asarray(
        [[0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 1, 0, 0, 0],
         [0, 1, 0, 0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0, 0, 1, 0],
         [0, 1, 0, 0, 0, 0, 1, 0]])  # pyformat: disable
    # Test the full range of angles.
    thetas = np.linspace(-np.pi, np.pi, 101)
    hough_bins = hough.hough_lines(image, thetas)
    peak_rho_t, peak_theta_t = hough.hough_peaks(hough_bins, thetas)
    with self.test_session() as sess:
      peak_rho, peak_theta = sess.run((peak_rho_t, peak_theta_t))
      # Vertical line
      self.assertEqual(peak_rho[0], 1)
      self.assertAlmostEqual(peak_theta[0], 0)
      # Rotated line
      self.assertEqual(peak_rho[1], 3)
      self.assertAlmostEqual(peak_theta[1], -np.pi / 8, places=1)

  def testHoughPeaks_minval(self):
    # A 3-pixel line meets the minval=2 threshold, so one peak is returned.
    image = np.asarray(
        [[0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0],
         [0, 0, 0, 0, 0]])  # pyformat: disable
    thetas = np.linspace(0, np.pi / 2, 17)
    hough_bins = hough.hough_lines(image, thetas)
    peak_rho_t, peak_theta_t = hough.hough_peaks(hough_bins, thetas, minval=2)
    with self.test_session() as sess:
      peak_rho, peak_theta = sess.run((peak_rho_t, peak_theta_t))
      self.assertEqual(peak_rho.shape, (1,))
      self.assertEqual(peak_theta.shape, (1,))

  def testHoughPeaks_minvalTooLarge(self):
    # minval above the maximum vote count (3) should reject every peak.
    image = np.asarray(
        [[0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0],
         [0, 0, 0, 0, 0]])  # pyformat: disable
    thetas = np.linspace(0, np.pi / 2, 17)
    hough_bins = hough.hough_lines(image, thetas)
    peak_rho_t, peak_theta_t = hough.hough_peaks(hough_bins, thetas, minval=3.1)
    with self.test_session() as sess:
      peak_rho, peak_theta = sess.run((peak_rho_t, peak_theta_t))
      self.assertEqual(peak_rho.shape, (0,))
      self.assertEqual(peak_theta.shape, (0,))
# Allow running the tests directly with `python hough_test.py`.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/vision/images.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import image as contrib_image
# TODO(ringw): Replace once github.com/tensorflow/tensorflow/pull/10748
# is in.
def translate(image, x, y):
  """Translates the image.

  Args:
    image: A 2D float32 tensor.
    x: The x shift of the output, in pixels.
    y: The y shift of the output, in pixels.

  Returns:
    The translated image tensor.
  """
  # TODO(ringw): Fix mixing scalar constants and scalar tensors here.
  one = tf.constant(1, tf.float32)
  zero = tf.constant(0, tf.float32)
  # tf.contrib.image.transform expects the *inverse* mapping: the flattened
  # first 8 entries of the 3x3 projective matrix (the final entry is fixed at
  # 1), so shifting the output by (x, y) means translating by (-x, -y) here.
  inverse_transform = tf.convert_to_tensor(
      [one, zero, tf.to_float(-x),
       zero, one, tf.to_float(-y),
       zero, zero], tf.float32)  # pyformat: disable
  return contrib_image.transform(image, inverse_transform)
================================================
FILE: moonlight/vision/images_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.vision import images
class ImagesTest(tf.test.TestCase):
  """Tests for images.translate."""

  def testTranslate(self):
    with self.test_session():
      # 3x3 image holding the values 0..8; vacated cells are zero-filled.
      arr = tf.reshape(tf.range(9), (3, 3))
      self.assertAllEqual(
          images.translate(arr, 0, -1).eval(),
          [[3, 4, 5], [6, 7, 8], [0, 0, 0]])
      self.assertAllEqual(
          images.translate(arr, 0, 1).eval(), [[0, 0, 0], [0, 1, 2], [3, 4, 5]])
      self.assertAllEqual(
          images.translate(arr, -1, 0).eval(),
          [[1, 2, 0], [4, 5, 0], [7, 8, 0]])
      self.assertAllEqual(
          images.translate(arr, 1, 0).eval(), [[0, 0, 1], [0, 3, 4], [0, 6, 7]])
# Allow running the tests directly with `python images_test.py`.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: moonlight/vision/morphology.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary morphology ops.
See: https://en.wikipedia.org/wiki/Mathematical_morphology#Binary_morphology
From the link above, these functions use a structuring element of `Z^2`--the
neighbors of a pixel are the pixels above, below, left, and right, which are
assumed to be False if they lie outside the image.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from moonlight.vision import images
def binary_erosion(image, n):
  """Erodes a boolean image `n` times.

  A True pixel survives an erosion pass only if all four of its 4-connected
  neighbors (treated as False beyond the image border) are also True.

  Args:
    image: 2D boolean tensor.
    n: Integer scalar tensor; the number of erosion passes to apply.

  Returns:
    2D boolean tensor holding the eroded image.
  """
  with tf.name_scope("binary_erosion"):
    bool_image = tf.convert_to_tensor(image, tf.bool, "image")
    eroded = _repeated_morphological_op(
        tf.to_float(bool_image), tf.logical_and, n)
    return tf.cast(eroded, tf.bool)
def binary_dilation(image, n):
  """Dilates a boolean image `n` times.

  A False pixel becomes True in a dilation pass if any of its four
  4-connected neighbors (treated as False beyond the image border) is True.

  Args:
    image: 2D boolean tensor.
    n: Integer scalar tensor; the number of dilation passes to apply.

  Returns:
    2D boolean tensor holding the dilated image.
  """
  with tf.name_scope("binary_dilation"):
    bool_image = tf.convert_to_tensor(image, tf.bool, "image")
    dilated = _repeated_morphological_op(
        tf.to_float(bool_image), tf.logical_or, n)
    return tf.cast(dilated, tf.bool)
def _repeated_morphological_op(float_image, binary_op, n):
  """Applies `_single_morphological_op` to `float_image` `n` times in-graph."""

  def _keep_going(step, unused_image):
    # Loop until `n` passes have been applied; supports a dynamic (tensor) n.
    return tf.less(step, n)

  def _one_pass(step, image):
    return step + 1, _single_morphological_op(image, binary_op)

  _, result = tf.while_loop(_keep_going, _one_pass,
                            [tf.constant(0), float_image])
  return result
def _single_morphological_op(float_image, binary_op):
  """One morphological pass: folds each pixel with its 4-connected neighbors.

  The image is translated one pixel in each of the four cardinal directions,
  and `binary_op` (logical and/or) accumulates the original pixel with each
  translated neighbor. Pixels translated in from outside the border are zero
  (False), which gives erosion/dilation their edge behavior.
  """
  with tf.name_scope("_single_morphological_op"):
    original = float_image
    accumulated = float_image
    for x_shift, y_shift in ((-1, 0), (0, -1), (1, 0), (0, 1)):
      neighbor = tf.cast(
          images.translate(original, x_shift, y_shift), tf.bool)
      accumulated = tf.to_float(
          binary_op(tf.cast(accumulated, tf.bool), neighbor))
    return accumulated
================================================
FILE: moonlight/vision/morphology_test.py
================================================
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for binary morphology."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from moonlight.vision import morphology
class MorphologyTest(tf.test.TestCase):
  """Checks binary erosion and dilation against hand-computed expectations."""

  # Shared 4x11 binary input used by testErosion and testDilation.
  _INPUT = [[1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1],
            [0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0]]

  def testMorphology_false(self):
    # An all-False image is a fixed point of both ops.
    for morphological_op in (morphology.binary_erosion,
                             morphology.binary_dilation):
      with self.test_session():
        result = morphological_op(tf.zeros((5, 3), tf.bool), n=1).eval()
        self.assertAllEqual(result, np.zeros((5, 3), np.bool))

  def testErosion_small(self):
    # Only the center of a plus shape survives one erosion.
    with self.test_session():
      eroded = morphology.binary_erosion(
          tf.cast([[0, 1, 0], [1, 1, 1], [0, 1, 0]], tf.bool), n=1).eval()
      self.assertAllEqual(eroded, [[0, 0, 0], [0, 1, 0], [0, 0, 0]])

  def testErosion(self):
    expected = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
    with self.test_session():
      eroded = morphology.binary_erosion(
          tf.cast(self._INPUT, tf.bool), n=1).eval()
      self.assertAllEqual(eroded, np.asarray(expected, np.bool))

  def testDilation(self):
    expected = [[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1]]
    with self.test_session():
      dilated = morphology.binary_dilation(
          tf.cast(self._INPUT, tf.bool), n=1).eval()
      self.assertAllEqual(dilated, np.asarray(expected, np.bool))
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: requirements.txt
================================================
# List of all packages used by Moonlight.
absl-py
# TODO(ringw): Get the latest apache beam version working in py2 tests.
apache_beam==2.5.0; python_version < '3.0'
apache_beam==2.11.0; python_version >= '3.0'
enum34
librosa==0.4.0
lxml
joblib==0.11.0
Mako
numpy>=1.14.2
pandas
Pillow
protobuf==3.6.1
scipy
tensorflow==1.15.4
tensorflow-estimator==1.15.1
================================================
FILE: sandbox/README.md
================================================
# Moonlight Sandbox
This directory contains the symlinks needed to import Moonlight after it has
been built with Bazel. Either add this directory to your PYTHONPATH, or run
Python from inside it.
git clone https://github.com/tensorflow/moonlight
cd moonlight
# You may want to run this inside a virtualenv.
pip install -r requirements.txt
# Builds dependencies and sets up the symlinks that we point to.
bazel build moonlight:omr
cd sandbox
python
>>> from moonlight import engine
================================================
FILE: six.BUILD
================================================
# Exposes the vendored six.py as a py_library that any package may depend on,
# for both Python 2 and Python 3 builds.
py_library(
    name = "six",
    srcs = ["six.py"],
    visibility = ["//visibility:public"],
    srcs_version = "PY2AND3",
)
================================================
FILE: tools/bazel_0.20.0-linux-x86_64.deb.sha256
================================================
ea47050fe839a7f5fb6c3ac1cc876a70993e614bab091aefc387715c3cb48a86 bazel_0.20.0-linux-x86_64.deb
================================================
FILE: tools/travis_tests.sh
================================================
#!/bin/bash
# Builds and tests Moonlight on Travis CI.
#
# Exit immediately if any command fails (`-e`): without it, only the final
# `bazel test` exit status would decide the build result, so a failed
# `bazel build` or a broken `moonlight.engine` import would go unnoticed.
# Also print commands before running them (`-x`).
set -ex

# Apache Beam only supports Python 2 :(
# Filter tests with tags = ["py2only"] for Python 3 (the TRAVIS_PYTHON_VERSION
# environment variable starts with a "3").
if [ "${TRAVIS_PYTHON_VERSION:0:1}" = 3 ]; then
  PYTHON_VERSION_FILTERS=--test_tag_filters=-py2only
fi

# Test that we can build and import the "engine" module in the sandbox.
bazel build --incompatible_remove_native_http_archive=false //moonlight:omr
PYTHONPATH=sandbox python -m moonlight.engine

bazel test --incompatible_remove_native_http_archive=false \
  --test_output=errors --local_test_jobs=1 $PYTHON_VERSION_FILTERS \
  //moonlight/...