Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions inference_schema/parameter_types/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ def get_swagger_for_list(python_data):
item_type = type(python_data[0])

for data in python_data:
if type(data) != item_type:
raise Exception('Error, OpenAPI 2.x does not support mixed type in array.')

if issubclass(item_type, AbstractParameterType):
nested_item_swagger = data.input_to_swagger()
else:
Expand Down
22 changes: 22 additions & 0 deletions inference_schema/parameter_types/abstract_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ def __init__(self, sample_input):
self.sample_input = sample_input
self.sample_data_type = type(sample_input)

def supported_versions(self):
schema = self.input_to_swagger()
supported_list = ['3.0', '3.1']
if self._supports_swagger_2(schema['example']):
supported_list += ['2.0']
return sorted(supported_list)

def _supports_swagger_2(self, obj):
if type(obj) is list:
first_type = type(obj[0])
for elt in obj:
if type(elt) is not first_type:
return False
elif type(elt) is list:
if not self._supports_swagger_2(elt):
return False
elif type(obj) is dict:
for elt in obj.values():
if not self._supports_swagger_2(elt):
return False
return True

@abstractmethod
def deserialize_input(self, input_data):
"""
Expand Down
68 changes: 63 additions & 5 deletions inference_schema/schema_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import copy
from functools import partial

from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__
from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__, __versions__
from .parameter_types.abstract_parameter_type import AbstractParameterType
from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR

Expand Down Expand Up @@ -39,8 +39,9 @@ def input_schema(param_name, param_type, convert_to_provided_type=True):
'of the AbstractParameterType.')

swagger_schema = {param_name: param_type.input_to_swagger()}
supported_versions = param_type.supported_versions()

@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema)
@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
def decorator_input(user_run, instance, args, kwargs):
if convert_to_provided_type:
args = list(args)
Expand Down Expand Up @@ -82,16 +83,17 @@ def output_schema(output_type):
'of the AbstractParameterType.')

swagger_schema = output_type.input_to_swagger()
supported_versions = output_type.supported_versions()

@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema)
@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
def decorator_input(user_run, instance, args, kwargs):
return user_run(*args, **kwargs)

return decorator_input


# Heavily based on the wrapt.decorator implementation
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None, supported_versions=None):
"""
Decorator to generate decorators, preserving the metadata passed to the
decorator arguments, that is needed to be able to extact that information
Expand All @@ -107,6 +109,8 @@ def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
:type attr_name: str | None
:param schema:
:type schema: dict | None
:param supported_versions:
:type supported_versions: List | None
:return:
:rtype: function | FunctionWrapper
"""
Expand Down Expand Up @@ -134,6 +138,7 @@ def _capture(target_wrapped):
return _capture

_add_schema_to_global_schema_dictionary(attr_name, schema, args[0])
_add_versions_to_global_versions_dictionary(attr_name, supported_versions, args[0])
target_wrapped = args[0]

_enabled = enabled
Expand Down Expand Up @@ -165,7 +170,8 @@ def _capture(target_wrapped):
_schema_decorator,
enabled=enabled,
attr_name=attr_name,
schema=schema
schema=schema,
supported_versions=supported_versions
)


Expand Down Expand Up @@ -201,6 +207,35 @@ def _add_schema_to_global_schema_dictionary(attr_name, schema, user_func):
pass


def _add_versions_to_global_versions_dictionary(attr_name, versions, user_func):
"""
function to add supported swagger versions for 'attr_name', to the function versions dict

:param attr_name:
:type attr_name: str
:param versions:
:type versions: List
:param user_func:
:type user_func: function | FunctionWrapper
:return:
:rtype:
"""

if attr_name is None or versions is None:
pass

decorators = _get_decorators(user_func)
base_func_name = _get_function_full_qual_name(decorators[-1])

if base_func_name not in __versions__.keys():
__versions__[base_func_name] = {}

if attr_name == INPUT_SCHEMA_ATTR or attr_name == OUTPUT_SCHEMA_ATTR:
_add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr_name)
else:
pass


def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, schema):
"""
function to add a generated input schema, to the function schema dict
Expand Down Expand Up @@ -233,6 +268,29 @@ def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, sch
__functions_schema__[base_func_name][INPUT_SCHEMA_ATTR]["properties"][k] = item_swagger


def _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr):
"""
function to add supported swagger versions to the version dict

:param base_func_name: function full qualified name
:type base_func_name: str
:param versions:
:type versions: list
:param attr:
:type attr: str
:return:
:rtype:
"""

if attr not in __versions__[base_func_name].keys():
__versions__[base_func_name][attr] = {
"type": "object",
"versions": {}
}

__versions__[base_func_name][attr]["versions"] = versions


def _add_output_schema_to_global_schema_dictionary(base_func_name, schema):
"""
function to add a generated output schema, to the function schema dict
Expand Down
19 changes: 19 additions & 0 deletions inference_schema/schema_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from inference_schema._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR

__functions_schema__ = {}
__versions__ = {}


def get_input_schema(func):
Expand Down Expand Up @@ -36,6 +37,24 @@ def get_output_schema(func):
return _get_schema_from_dictionary(OUTPUT_SCHEMA_ATTR, func)


def get_supported_versions(func):
"""
Extract supported swagger versions from the decorated function.

:param func:
:type func: function | FunctionWrapper
:return:
:rtype: list
"""
decorators = _get_decorators(func)
func_base_name = _get_function_full_qual_name(decorators[-1])

input_versions = __versions__.get(func_base_name, {}).get(INPUT_SCHEMA_ATTR, {}).get('versions', [])
output_versions = __versions__.get(func_base_name, {}).get(OUTPUT_SCHEMA_ATTR, {}).get('versions', [])
set_intersection = set(input_versions) & set(output_versions)
return sorted(list(set_intersection))


def get_schemas_dict():
"""
Retrieve a deepcopy of the dictionary that is used to track the provided function schemas
Expand Down
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,27 @@ def standard_py_func(param):
return standard_py_func


@pytest.fixture(scope="session")
def standard_sample_input_multitype_list():
return ['foo', 1]


@pytest.fixture(scope="session")
def standard_sample_output_multitype_list():
return 5


@pytest.fixture(scope="session")
def decorated_standard_func_multitype_list(standard_sample_input_multitype_list, standard_sample_output_multitype_list):
@input_schema('param', StandardPythonParameterType(standard_sample_input_multitype_list))
@output_schema(StandardPythonParameterType(standard_sample_output_multitype_list))
def standard_py_func_multitype_list(param):
assert type(param) is list
return param[1]

return standard_py_func_multitype_list


@pytest.fixture(scope="session")
def decorated_float_func():
@input_schema('param', StandardPythonParameterType(1.0))
Expand Down
6 changes: 6 additions & 0 deletions tests/test_numpy_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------

import numpy as np
from inference_schema.schema_util import get_supported_versions


class TestNumpyParameterType(object):
Expand All @@ -21,3 +22,8 @@ def test_numpy_handling(self, decorated_numpy_func):
numpy_input = {"param": [{"name": "Sarah", "grades": [8.0, 7.0]}]}
result = decorated_numpy_func(**numpy_input)
assert np.array_equal(result, grades)

version_list = get_supported_versions(decorated_numpy_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list
6 changes: 6 additions & 0 deletions tests/test_pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from pandas.testing import assert_frame_equal
from inference_schema.schema_util import get_supported_versions


class TestPandasParameterType(object):
Expand All @@ -25,6 +26,11 @@ def test_pandas_handling(self, decorated_pandas_func):
result = decorated_pandas_func(**pandas_input)
assert_frame_equal(result, state)

version_list = get_supported_versions(decorated_pandas_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list

def test_pandas_orient_handling(self, decorated_pandas_func_split_orient):
pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]}
state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state'])
Expand Down
6 changes: 6 additions & 0 deletions tests/test_spark_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from pyspark.sql.session import SparkSession
from inference_schema.schema_util import get_supported_versions


class TestSparkParameterType(object):
Expand All @@ -25,3 +26,8 @@ def test_spark_handling(self, decorated_spark_func):
spark_input = {'param': [{'name': 'Sarah', 'state': 'WA'}]}
result = decorated_spark_func(**spark_input)
assert state.subtract(result).count() == result.subtract(state).count() == 0

version_list = get_supported_versions(decorated_spark_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list
22 changes: 21 additions & 1 deletion tests/test_standard_parameter_type.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from inference_schema.parameter_types.standard_py_parameter_type import StandardPythonParameterType
from inference_schema.schema_util import get_supported_versions


class TestStandardPythonParameterType(object):

def test_standard_handling(self, decorated_standard_func):
def test_standard_handling_unique(self, decorated_standard_func):
standard_input = {'name': ['Sarah'], 'state': ['WA']}
state = {'state': ['WA']}
result = decorated_standard_func(standard_input)
Expand All @@ -15,6 +17,24 @@ def test_standard_handling(self, decorated_standard_func):
result = decorated_standard_func(**standard_input)
assert state == result

version_list = get_supported_versions(decorated_standard_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list

def test_standard_handling_list(self, decorated_standard_func_multitype_list):
standard_input = ['foo', 1]
assert 1 == decorated_standard_func_multitype_list(standard_input)

version_list = get_supported_versions(decorated_standard_func_multitype_list)
assert '2.0' not in version_list
assert '3.0' in version_list
assert '3.1' in version_list

def test_supported_versions_string(self):
assert '2.0' in StandardPythonParameterType({'name': ['Sarah'], 'state': ['WA']}).supported_versions()
assert '2.0' not in StandardPythonParameterType(['foo', 1]).supported_versions()

def test_float_int_handling(self, decorated_float_func):
float_input = 1.0
result = decorated_float_func(float_input)
Expand Down