Skip to content

Commit 8c79102

Browse files
committed
Add CORS support to buckets.
See: http://www.w3.org/TR/cors/ and https://cloud.google.com/storage/docs/json_api/v1/buckets Addresses 'cors' part of 314.
1 parent aa22386 commit 8c79102

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

gcloud/storage/bucket.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Bucket(_MetadataMixin):
2424
CUSTOM_METADATA_FIELDS = {
2525
'acl': 'get_acl',
2626
'defaultObjectAcl': 'get_default_object_acl',
27+
'cors': 'get_cors',
2728
}
2829
"""Mapping of field name -> accessor for fields w/ custom accessors."""
2930

@@ -441,6 +442,58 @@ def make_public(self, recursive=False, future=False):
441442
key.get_acl().all().grant_read()
442443
key.save_acl()
443444

445+
def get_cors(self):
446+
"""Retrieve CORS policies configured for this bucket.
447+
448+
See: http://www.w3.org/TR/cors/ and
449+
https://cloud.google.com/storage/docs/json_api/v1/buckets
450+
451+
:rtype: list(dict)
452+
:returns: A sequence of mappings describing each CORS policy.
453+
Keys include 'max_age', 'methods', 'origins', and
454+
'headers'.
455+
"""
456+
if not self.has_metadata('cors'):
457+
self.reload_metadata()
458+
result = []
459+
for entry in self.metadata.get('cors', ()):
460+
entry = entry.copy()
461+
result.append(entry)
462+
if 'maxAgeSeconds' in entry:
463+
entry['max_age'] = entry.pop('maxAgeSeconds')
464+
if 'method' in entry:
465+
entry['methods'] = entry.pop('method')
466+
if 'origin' in entry:
467+
entry['origins'] = entry.pop('origin')
468+
if 'responseHeader' in entry:
469+
entry['headers'] = entry.pop('responseHeader')
470+
return result
471+
472+
def update_cors(self, entries):
473+
"""Update CORS policies configured for this bucket.
474+
475+
See: http://www.w3.org/TR/cors/ and
476+
https://cloud.google.com/storage/docs/json_api/v1/buckets
477+
478+
:type entries: list(dict)
479+
:param entries: A sequence of mappings describing each CORS policy.
480+
Keys include 'max_age', 'methods', 'origins', and
481+
'headers'.
482+
"""
483+
to_patch = []
484+
for entry in entries:
485+
entry = entry.copy()
486+
to_patch.append(entry)
487+
if 'max_age' in entry:
488+
entry['maxAgeSeconds'] = entry.pop('max_age')
489+
if 'methods' in entry:
490+
entry['method'] = entry.pop('methods')
491+
if 'origins' in entry:
492+
entry['origin'] = entry.pop('origins')
493+
if 'headers' in entry:
494+
entry['responseHeader'] = entry.pop('headers')
495+
self.patch_metadata({'cors': to_patch})
496+
444497

445498
class BucketIterator(Iterator):
446499
"""An iterator listing all buckets.

gcloud/storage/test_bucket.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,23 @@ def test_get_metadata_none_set_defaultObjectAcl_miss_clear_default(self):
489489
kw = connection._requested
490490
self.assertEqual(len(kw), 0)
491491

492+
def test_get_metadata_cors_no_default(self):
493+
NAME = 'name'
494+
connection = _Connection()
495+
bucket = self._makeOne(connection, NAME)
496+
self.assertRaises(KeyError, bucket.get_metadata, 'cors')
497+
kw = connection._requested
498+
self.assertEqual(len(kw), 0)
499+
500+
def test_get_metadata_none_set_cors_w_default(self):
501+
NAME = 'name'
502+
connection = _Connection()
503+
bucket = self._makeOne(connection, NAME)
504+
default = object()
505+
self.assertRaises(KeyError, bucket.get_metadata, 'cors', default)
506+
kw = connection._requested
507+
self.assertEqual(len(kw), 0)
508+
492509
def test_get_metadata_miss(self):
493510
NAME = 'name'
494511
before = {'bar': 'Bar'}
@@ -713,6 +730,77 @@ def get_items_from_response(self, response):
713730
self.assertEqual(kw[1]['path'], '/b/%s/o' % NAME)
714731
self.assertEqual(kw[1]['query_params'], None)
715732

733+
def test_get_cors_eager(self):
734+
NAME = 'name'
735+
CORS_ENTRY = {
736+
'maxAgeSeconds': 1234,
737+
'method': ['OPTIONS', 'GET'],
738+
'origin': ['127.0.0.1'],
739+
'responseHeader': ['Content-Type'],
740+
}
741+
before = {'cors': [CORS_ENTRY, {}]}
742+
connection = _Connection()
743+
bucket = self._makeOne(connection, NAME, before)
744+
entries = bucket.get_cors()
745+
self.assertEqual(len(entries), 2)
746+
self.assertEqual(entries[0]['max_age'], CORS_ENTRY['maxAgeSeconds'])
747+
self.assertEqual(entries[0]['methods'], CORS_ENTRY['method'])
748+
self.assertEqual(entries[0]['origins'], CORS_ENTRY['origin'])
749+
self.assertEqual(entries[0]['headers'], CORS_ENTRY['responseHeader'])
750+
self.assertEqual(entries[1], {})
751+
kw = connection._requested
752+
self.assertEqual(len(kw), 0)
753+
754+
def test_get_cors_lazy(self):
755+
NAME = 'name'
756+
CORS_ENTRY = {
757+
'maxAgeSeconds': 1234,
758+
'method': ['OPTIONS', 'GET'],
759+
'origin': ['127.0.0.1'],
760+
'responseHeader': ['Content-Type'],
761+
}
762+
after = {'cors': [CORS_ENTRY]}
763+
connection = _Connection(after)
764+
bucket = self._makeOne(connection, NAME)
765+
entries = bucket.get_cors()
766+
self.assertEqual(len(entries), 1)
767+
self.assertEqual(entries[0]['max_age'], CORS_ENTRY['maxAgeSeconds'])
768+
self.assertEqual(entries[0]['methods'], CORS_ENTRY['method'])
769+
self.assertEqual(entries[0]['origins'], CORS_ENTRY['origin'])
770+
self.assertEqual(entries[0]['headers'], CORS_ENTRY['responseHeader'])
771+
kw = connection._requested
772+
self.assertEqual(len(kw), 1)
773+
self.assertEqual(kw[0]['method'], 'GET')
774+
self.assertEqual(kw[0]['path'], '/b/%s' % NAME)
775+
self.assertEqual(kw[0]['query_params'], {'projection': 'noAcl'})
776+
777+
def test_update_cors(self):
778+
NAME = 'name'
779+
CORS_ENTRY = {
780+
'maxAgeSeconds': 1234,
781+
'method': ['OPTIONS', 'GET'],
782+
'origin': ['127.0.0.1'],
783+
'responseHeader': ['Content-Type'],
784+
}
785+
MAPPED = {
786+
'max_age': 1234,
787+
'methods': ['OPTIONS', 'GET'],
788+
'origins': ['127.0.0.1'],
789+
'headers': ['Content-Type'],
790+
}
791+
after = {'cors': [CORS_ENTRY, {}]}
792+
connection = _Connection(after)
793+
bucket = self._makeOne(connection, NAME)
794+
bucket.update_cors([MAPPED, {}])
795+
kw = connection._requested
796+
self.assertEqual(len(kw), 1)
797+
self.assertEqual(kw[0]['method'], 'PATCH')
798+
self.assertEqual(kw[0]['path'], '/b/%s' % NAME)
799+
self.assertEqual(kw[0]['data'], after)
800+
self.assertEqual(kw[0]['query_params'], {'projection': 'full'})
801+
entries = bucket.get_cors()
802+
self.assertEqual(entries, [MAPPED, {}])
803+
716804

717805
class TestBucketIterator(unittest2.TestCase):
718806

0 commit comments

Comments
 (0)