Skip to content

Commit e222714

Browse files
authored
Merge pull request #2770 from daspecster/add-manual-detect-to-vision-2697
Add image.detect() for detecting multiple types.
2 parents d11aa28 + fc55c18 commit e222714

3 files changed

Lines changed: 208 additions & 64 deletions

File tree

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2016 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Annotations management for Vision API responses."""
16+
17+
18+
from google.cloud.vision.color import ImagePropertiesAnnotation
19+
from google.cloud.vision.entity import EntityAnnotation
20+
from google.cloud.vision.face import Face
21+
from google.cloud.vision.safe import SafeSearchAnnotation
22+
23+
24+
FACE_ANNOTATIONS = 'faceAnnotations'
25+
IMAGE_PROPERTIES_ANNOTATION = 'imagePropertiesAnnotation'
26+
SAFE_SEARCH_ANNOTATION = 'safeSearchAnnotation'
27+
28+
_KEY_MAP = {
29+
FACE_ANNOTATIONS: 'faces',
30+
IMAGE_PROPERTIES_ANNOTATION: 'properties',
31+
'labelAnnotations': 'labels',
32+
'landmarkAnnotations': 'landmarks',
33+
'logoAnnotations': 'logos',
34+
SAFE_SEARCH_ANNOTATION: 'safe_searches',
35+
'textAnnotations': 'texts'
36+
}
37+
38+
39+
class Annotations(object):
40+
"""Helper class to bundle annotation responses.
41+
42+
:type faces: list
43+
:param faces: List of :class:`~google.cloud.vision.face.Face`.
44+
45+
:type properties: list
46+
:param properties:
47+
List of :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
48+
49+
:type labels: list
50+
:param labels: List of
51+
:class:`~google.cloud.vision.entity.EntityAnnotation`.
52+
53+
:type landmarks: list
54+
:param landmarks: List of
55+
:class:`~google.cloud.vision.entity.EntityAnnotation.`
56+
57+
:type logos: list
58+
:param logos: List of
59+
:class:`~google.cloud.vision.entity.EntityAnnotation`.
60+
61+
:type safe_searches: list
62+
:param safe_searches:
63+
List of :class:`~google.cloud.vision.safe.SafeSearchAnnotation`
64+
65+
:type texts: list
66+
:param texts: List of
67+
:class:`~google.cloud.vision.entity.EntityAnnotation`.
68+
"""
69+
def __init__(self, faces=(), properties=(), labels=(), landmarks=(),
70+
logos=(), safe_searches=(), texts=()):
71+
self.faces = faces
72+
self.properties = properties
73+
self.labels = labels
74+
self.landmarks = landmarks
75+
self.logos = logos
76+
self.safe_searches = safe_searches
77+
self.texts = texts
78+
79+
@classmethod
80+
def from_api_repr(cls, response):
81+
"""Factory: construct an instance of ``Annotations`` from a response.
82+
83+
:type response: dict
84+
:param response: Vision API response object.
85+
86+
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
87+
:returns: An instance of ``Annotations`` with detection types loaded.
88+
"""
89+
annotations = {}
90+
for feature_type, annotation in response.items():
91+
curr_feature = annotations.setdefault(_KEY_MAP[feature_type], [])
92+
curr_feature.extend(
93+
_entity_from_response_type(feature_type, annotation))
94+
return cls(**annotations)
95+
96+
97+
def _entity_from_response_type(feature_type, results):
98+
"""Convert a JSON result to an entity type based on the feature.
99+
100+
:rtype: list
101+
:returns: List containing any of
102+
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
103+
:class:`~google.cloud.vision.entity.EntityAnnotation`,
104+
:class:`~google.cloud.vision.face.Face`,
105+
:class:`~google.cloud.vision.safe.SafeSearchAnnotation`.
106+
"""
107+
detected_objects = []
108+
if feature_type == FACE_ANNOTATIONS:
109+
detected_objects.extend(
110+
Face.from_api_repr(face) for face in results)
111+
elif feature_type == IMAGE_PROPERTIES_ANNOTATION:
112+
detected_objects.append(
113+
ImagePropertiesAnnotation.from_api_repr(results))
114+
elif feature_type == SAFE_SEARCH_ANNOTATION:
115+
detected_objects.append(SafeSearchAnnotation.from_api_repr(results))
116+
else:
117+
for result in results:
118+
detected_objects.append(EntityAnnotation.from_api_repr(result))
119+
return detected_objects

packages/google-cloud-vision/google/cloud/vision/image.py

Lines changed: 30 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,31 +19,9 @@
1919

2020
from google.cloud._helpers import _to_bytes
2121
from google.cloud._helpers import _bytes_to_unicode
22-
from google.cloud.vision.entity import EntityAnnotation
23-
from google.cloud.vision.face import Face
22+
from google.cloud.vision.annotations import Annotations
2423
from google.cloud.vision.feature import Feature
2524
from google.cloud.vision.feature import FeatureTypes
26-
from google.cloud.vision.color import ImagePropertiesAnnotation
27-
from google.cloud.vision.safe import SafeSearchAnnotation
28-
29-
30-
_FACE_DETECTION = 'FACE_DETECTION'
31-
_IMAGE_PROPERTIES = 'IMAGE_PROPERTIES'
32-
_LABEL_DETECTION = 'LABEL_DETECTION'
33-
_LANDMARK_DETECTION = 'LANDMARK_DETECTION'
34-
_LOGO_DETECTION = 'LOGO_DETECTION'
35-
_SAFE_SEARCH_DETECTION = 'SAFE_SEARCH_DETECTION'
36-
_TEXT_DETECTION = 'TEXT_DETECTION'
37-
38-
_REVERSE_TYPES = {
39-
_FACE_DETECTION: 'faceAnnotations',
40-
_IMAGE_PROPERTIES: 'imagePropertiesAnnotation',
41-
_LABEL_DETECTION: 'labelAnnotations',
42-
_LANDMARK_DETECTION: 'landmarkAnnotations',
43-
_LOGO_DETECTION: 'logoAnnotations',
44-
_SAFE_SEARCH_DETECTION: 'safeSearchAnnotation',
45-
_TEXT_DETECTION: 'textAnnotations',
46-
}
4725

4826

4927
class Image(object):
@@ -105,7 +83,7 @@ def source(self):
10583
return self._source
10684

10785
def _detect_annotation(self, features):
108-
"""Generic method for detecting a single annotation.
86+
"""Generic method for detecting annotations.
10987
11088
:type features: list
11189
:param features: List of :class:`~google.cloud.vision.feature.Feature`
@@ -118,12 +96,21 @@ def _detect_annotation(self, features):
11896
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
11997
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
12098
"""
121-
detected_objects = []
12299
results = self.client.annotate(self, features)
123-
for feature in features:
124-
detected_objects.extend(
125-
_entity_from_response_type(feature.feature_type, results))
126-
return detected_objects
100+
return Annotations.from_api_repr(results)
101+
102+
def detect(self, features):
103+
"""Detect multiple feature types.
104+
105+
:type features: list of :class:`~google.cloud.vision.feature.Feature`
106+
:param features: List of the ``Feature`` indication the type of
107+
annotation to perform.
108+
109+
:rtype: list
110+
:returns: List of
111+
:class:`~google.cloud.vision.entity.EntityAnnotation`.
112+
"""
113+
return self._detect_annotation(features)
127114

128115
def detect_faces(self, limit=10):
129116
"""Detect faces in image.
@@ -135,7 +122,8 @@ def detect_faces(self, limit=10):
135122
:returns: List of :class:`~google.cloud.vision.face.Face`.
136123
"""
137124
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
138-
return self._detect_annotation(features)
125+
annotations = self._detect_annotation(features)
126+
return annotations.faces
139127

140128
def detect_labels(self, limit=10):
141129
"""Detect labels that describe objects in an image.
@@ -147,7 +135,8 @@ def detect_labels(self, limit=10):
147135
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
148136
"""
149137
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
150-
return self._detect_annotation(features)
138+
annotations = self._detect_annotation(features)
139+
return annotations.labels
151140

152141
def detect_landmarks(self, limit=10):
153142
"""Detect landmarks in an image.
@@ -160,7 +149,8 @@ def detect_landmarks(self, limit=10):
160149
:class:`~google.cloud.vision.entity.EntityAnnotation`.
161150
"""
162151
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
163-
return self._detect_annotation(features)
152+
annotations = self._detect_annotation(features)
153+
return annotations.landmarks
164154

165155
def detect_logos(self, limit=10):
166156
"""Detect logos in an image.
@@ -173,7 +163,8 @@ def detect_logos(self, limit=10):
173163
:class:`~google.cloud.vision.entity.EntityAnnotation`.
174164
"""
175165
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
176-
return self._detect_annotation(features)
166+
annotations = self._detect_annotation(features)
167+
return annotations.logos
177168

178169
def detect_properties(self, limit=10):
179170
"""Detect the color properties of an image.
@@ -186,7 +177,8 @@ def detect_properties(self, limit=10):
186177
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
187178
"""
188179
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
189-
return self._detect_annotation(features)
180+
annotations = self._detect_annotation(features)
181+
return annotations.properties
190182

191183
def detect_safe_search(self, limit=10):
192184
"""Retreive safe search properties from an image.
@@ -199,7 +191,8 @@ def detect_safe_search(self, limit=10):
199191
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
200192
"""
201193
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
202-
return self._detect_annotation(features)
194+
annotations = self._detect_annotation(features)
195+
return annotations.safe_searches
203196

204197
def detect_text(self, limit=10):
205198
"""Detect text in an image.
@@ -212,27 +205,5 @@ def detect_text(self, limit=10):
212205
:class:`~google.cloud.vision.entity.EntityAnnotation`.
213206
"""
214207
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
215-
return self._detect_annotation(features)
216-
217-
218-
def _entity_from_response_type(feature_type, results):
219-
"""Convert a JSON result to an entity type based on the feature."""
220-
feature_key = _REVERSE_TYPES[feature_type]
221-
annotations = results.get(feature_key, ())
222-
if not annotations:
223-
return []
224-
225-
detected_objects = []
226-
if feature_type == _FACE_DETECTION:
227-
detected_objects.extend(
228-
Face.from_api_repr(face) for face in annotations)
229-
elif feature_type == _IMAGE_PROPERTIES:
230-
detected_objects.append(
231-
ImagePropertiesAnnotation.from_api_repr(annotations))
232-
elif feature_type == _SAFE_SEARCH_DETECTION:
233-
detected_objects.append(
234-
SafeSearchAnnotation.from_api_repr(annotations))
235-
else:
236-
for result in annotations:
237-
detected_objects.append(EntityAnnotation.from_api_repr(result))
238-
return detected_objects
208+
annotations = self._detect_annotation(features)
209+
return annotations.texts

packages/google-cloud-vision/unit_tests/test_client.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,60 @@ def test_image_with_client(self):
8181
image = client.image(source_uri=IMAGE_SOURCE)
8282
self.assertIsInstance(image, Image)
8383

84+
def test_multiple_detection_from_content(self):
85+
import copy
86+
from google.cloud.vision.feature import Feature
87+
from google.cloud.vision.feature import FeatureTypes
88+
from unit_tests._fixtures import LABEL_DETECTION_RESPONSE
89+
from unit_tests._fixtures import LOGO_DETECTION_RESPONSE
90+
91+
returned = copy.deepcopy(LABEL_DETECTION_RESPONSE)
92+
logos = copy.deepcopy(LOGO_DETECTION_RESPONSE['responses'][0])
93+
returned['responses'][0]['logoAnnotations'] = logos['logoAnnotations']
94+
95+
credentials = _Credentials()
96+
client = self._make_one(project=PROJECT, credentials=credentials)
97+
client._connection = _Connection(returned)
98+
99+
limit = 2
100+
label_feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
101+
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, limit)
102+
features = [label_feature, logo_feature]
103+
image = client.image(content=IMAGE_CONTENT)
104+
items = image.detect(features)
105+
106+
self.assertEqual(len(items.logos), 2)
107+
self.assertEqual(len(items.labels), 3)
108+
first_logo = items.logos[0]
109+
second_logo = items.logos[1]
110+
self.assertEqual(first_logo.description, 'Brand1')
111+
self.assertEqual(first_logo.score, 0.63192177)
112+
self.assertEqual(second_logo.description, 'Brand2')
113+
self.assertEqual(second_logo.score, 0.5492993)
114+
115+
first_label = items.labels[0]
116+
second_label = items.labels[1]
117+
third_label = items.labels[2]
118+
self.assertEqual(first_label.description, 'automobile')
119+
self.assertEqual(first_label.score, 0.9776855)
120+
self.assertEqual(second_label.description, 'vehicle')
121+
self.assertEqual(second_label.score, 0.947987)
122+
self.assertEqual(third_label.description, 'truck')
123+
self.assertEqual(third_label.score, 0.88429511)
124+
125+
requested = client._connection._requested
126+
requests = requested[0]['data']['requests']
127+
image_request = requests[0]
128+
label_request = image_request['features'][0]
129+
logo_request = image_request['features'][1]
130+
131+
self.assertEqual(B64_IMAGE_CONTENT,
132+
image_request['image']['content'])
133+
self.assertEqual(label_request['maxResults'], 2)
134+
self.assertEqual(label_request['type'], 'LABEL_DETECTION')
135+
self.assertEqual(logo_request['maxResults'], 2)
136+
self.assertEqual(logo_request['type'], 'LOGO_DETECTION')
137+
84138
def test_face_detection_from_source(self):
85139
from google.cloud.vision.face import Face
86140
from unit_tests._fixtures import FACE_DETECTION_RESPONSE
@@ -126,7 +180,7 @@ def test_face_detection_from_content_no_results(self):
126180

127181
image = client.image(content=IMAGE_CONTENT)
128182
faces = image.detect_faces(limit=5)
129-
self.assertEqual(faces, [])
183+
self.assertEqual(faces, ())
130184
self.assertEqual(len(faces), 0)
131185
image_request = client._connection._requested[0]['data']['requests'][0]
132186

@@ -166,7 +220,7 @@ def test_label_detection_no_results(self):
166220

167221
image = client.image(content=IMAGE_CONTENT)
168222
labels = image.detect_labels()
169-
self.assertEqual(labels, [])
223+
self.assertEqual(labels, ())
170224
self.assertEqual(len(labels), 0)
171225

172226
def test_landmark_detection_from_source(self):
@@ -219,7 +273,7 @@ def test_landmark_detection_no_results(self):
219273

220274
image = client.image(content=IMAGE_CONTENT)
221275
landmarks = image.detect_landmarks()
222-
self.assertEqual(landmarks, [])
276+
self.assertEqual(landmarks, ())
223277
self.assertEqual(len(landmarks), 0)
224278

225279
def test_logo_detection_from_source(self):
@@ -308,7 +362,7 @@ def test_safe_search_no_results(self):
308362

309363
image = client.image(content=IMAGE_CONTENT)
310364
safe_search = image.detect_safe_search()
311-
self.assertEqual(safe_search, [])
365+
self.assertEqual(safe_search, ())
312366
self.assertEqual(len(safe_search), 0)
313367

314368
def test_image_properties_detection_from_source(self):
@@ -344,7 +398,7 @@ def test_image_properties_no_results(self):
344398

345399
image = client.image(content=IMAGE_CONTENT)
346400
image_properties = image.detect_properties()
347-
self.assertEqual(image_properties, [])
401+
self.assertEqual(image_properties, ())
348402
self.assertEqual(len(image_properties), 0)
349403

350404

0 commit comments

Comments
 (0)