diff --git a/inference_schema/parameter_types/_util.py b/inference_schema/parameter_types/_util.py index ed5bec0..a321df8 100644 --- a/inference_schema/parameter_types/_util.py +++ b/inference_schema/parameter_types/_util.py @@ -68,9 +68,6 @@ def get_swagger_for_list(python_data): item_type = type(python_data[0]) for data in python_data: - if type(data) != item_type: - raise Exception('Error, OpenAPI 2.x does not support mixed type in array.') - if issubclass(item_type, AbstractParameterType): nested_item_swagger = data.input_to_swagger() else: diff --git a/inference_schema/parameter_types/abstract_parameter_type.py b/inference_schema/parameter_types/abstract_parameter_type.py index 71a8687..aa5e5f3 100644 --- a/inference_schema/parameter_types/abstract_parameter_type.py +++ b/inference_schema/parameter_types/abstract_parameter_type.py @@ -24,6 +24,28 @@ def __init__(self, sample_input): self.sample_input = sample_input self.sample_data_type = type(sample_input) + def supported_versions(self): + schema = self.input_to_swagger() + supported_list = ['3.0', '3.1'] + if self._supports_swagger_2(schema['example']): + supported_list += ['2.0'] + return sorted(supported_list) + + def _supports_swagger_2(self, obj): + if type(obj) is list: + first_type = type(obj[0]) + for elt in obj: + if type(elt) is not first_type: + return False + elif type(elt) is list: + if not self._supports_swagger_2(elt): + return False + elif type(obj) is dict: + for elt in obj.values(): + if not self._supports_swagger_2(elt): + return False + return True + @abstractmethod def deserialize_input(self, input_data): """ diff --git a/inference_schema/schema_decorators.py b/inference_schema/schema_decorators.py index 81adf91..d5f2c0f 100644 --- a/inference_schema/schema_decorators.py +++ b/inference_schema/schema_decorators.py @@ -7,7 +7,7 @@ import copy from functools import partial -from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__ +from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__, __versions__ from .parameter_types.abstract_parameter_type import AbstractParameterType from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR @@ -39,8 +39,9 @@ def input_schema(param_name, param_type, convert_to_provided_type=True): 'of the AbstractParameterType.') swagger_schema = {param_name: param_type.input_to_swagger()} + supported_versions = param_type.supported_versions() - @_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema) + @_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions) def decorator_input(user_run, instance, args, kwargs): if convert_to_provided_type: args = list(args) @@ -82,8 +83,9 @@ def output_schema(output_type): 'of the AbstractParameterType.') swagger_schema = output_type.input_to_swagger() + supported_versions = output_type.supported_versions() - @_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema) + @_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions) def decorator_input(user_run, instance, args, kwargs): return user_run(*args, **kwargs) @@ -91,7 +93,7 @@ def decorator_input(user_run, instance, args, kwargs): # Heavily based on the wrapt.decorator implementation -def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None): +def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None, supported_versions=None): """ Decorator to generate decorators, preserving the metadata passed to the decorator arguments, that is needed to be able to extact that information @@ -107,6 +109,8 @@ def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None): :type attr_name: str | None :param schema: :type schema: dict | None + :param supported_versions: + :type supported_versions: List | None :return: :rtype: function | FunctionWrapper """ @@ -134,6 +138,7 @@ def _capture(target_wrapped): return _capture _add_schema_to_global_schema_dictionary(attr_name, schema, args[0]) + _add_versions_to_global_versions_dictionary(attr_name, supported_versions, args[0]) target_wrapped = args[0] _enabled = enabled @@ -165,7 +170,8 @@ def _capture(target_wrapped): _schema_decorator, enabled=enabled, attr_name=attr_name, - schema=schema + schema=schema, + supported_versions=supported_versions ) @@ -201,6 +207,35 @@ def _add_schema_to_global_schema_dictionary(attr_name, schema, user_func): pass +def _add_versions_to_global_versions_dictionary(attr_name, versions, user_func): + """ + function to add supported swagger versions for 'attr_name', to the function versions dict + + :param attr_name: + :type attr_name: str + :param versions: + :type versions: List + :param user_func: + :type user_func: function | FunctionWrapper + :return: + :rtype: + """ + + if attr_name is None or versions is None: + pass + + decorators = _get_decorators(user_func) + base_func_name = _get_function_full_qual_name(decorators[-1]) + + if base_func_name not in __versions__.keys(): + __versions__[base_func_name] = {} + + if attr_name == INPUT_SCHEMA_ATTR or attr_name == OUTPUT_SCHEMA_ATTR: + _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr_name) + else: + pass + + def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, schema): """ function to add a generated input schema, to the function schema dict @@ -233,6 +268,29 @@ def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, sch __functions_schema__[base_func_name][INPUT_SCHEMA_ATTR]["properties"][k] = item_swagger +def _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr): + """ + function to add supported swagger versions to the version dict + + :param base_func_name: function full qualified name + :type base_func_name: str + :param versions: + :type versions: list + :param attr: + :type attr: str + :return: + :rtype: + """ + + if attr not in __versions__[base_func_name].keys(): + __versions__[base_func_name][attr] = { + "type": "object", + "versions": {} + } + + __versions__[base_func_name][attr]["versions"] = versions + + def _add_output_schema_to_global_schema_dictionary(base_func_name, schema): """ function to add a generated output schema, to the function schema dict diff --git a/inference_schema/schema_util.py b/inference_schema/schema_util.py index a652ba1..9335226 100644 --- a/inference_schema/schema_util.py +++ b/inference_schema/schema_util.py @@ -8,6 +8,7 @@ from inference_schema._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR __functions_schema__ = {} +__versions__ = {} def get_input_schema(func): @@ -36,6 +37,24 @@ def get_output_schema(func): return _get_schema_from_dictionary(OUTPUT_SCHEMA_ATTR, func) +def get_supported_versions(func): + """ + Extract supported swagger versions from the decorated function. + + :param func: + :type func: function | FunctionWrapper + :return: + :rtype: list + """ + decorators = _get_decorators(func) + func_base_name = _get_function_full_qual_name(decorators[-1]) + + input_versions = __versions__.get(func_base_name, {}).get(INPUT_SCHEMA_ATTR, {}).get('versions', []) + output_versions = __versions__.get(func_base_name, {}).get(OUTPUT_SCHEMA_ATTR, {}).get('versions', []) + set_intersection = set(input_versions) & set(output_versions) + return sorted(list(set_intersection)) + + def get_schemas_dict(): """ Retrieve a deepcopy of the dictionary that is used to track the provided function schemas diff --git a/tests/conftest.py b/tests/conftest.py index 4c8aa7a..6adf7fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -184,6 +184,27 @@ def standard_py_func(param): return standard_py_func +@pytest.fixture(scope="session") +def standard_sample_input_multitype_list(): + return ['foo', 1] + + +@pytest.fixture(scope="session") +def standard_sample_output_multitype_list(): + return 5 + + +@pytest.fixture(scope="session") +def decorated_standard_func_multitype_list(standard_sample_input_multitype_list, standard_sample_output_multitype_list): + @input_schema('param', StandardPythonParameterType(standard_sample_input_multitype_list)) + @output_schema(StandardPythonParameterType(standard_sample_output_multitype_list)) + def standard_py_func_multitype_list(param): + assert type(param) is list + return param[1] + + return standard_py_func_multitype_list + + @pytest.fixture(scope="session") def decorated_float_func(): @input_schema('param', StandardPythonParameterType(1.0)) diff --git a/tests/test_numpy_parameter_type.py b/tests/test_numpy_parameter_type.py index 30f0311..81a5d8a 100644 --- a/tests/test_numpy_parameter_type.py +++ b/tests/test_numpy_parameter_type.py @@ -3,6 +3,7 @@ # --------------------------------------------------------- import numpy as np +from inference_schema.schema_util import get_supported_versions class TestNumpyParameterType(object): @@ -21,3 +22,8 @@ def test_numpy_handling(self, decorated_numpy_func): numpy_input = {"param": [{"name": "Sarah", "grades": [8.0, 7.0]}]} result = decorated_numpy_func(**numpy_input) assert np.array_equal(result, grades) + + version_list = get_supported_versions(decorated_numpy_func) + assert '2.0' in version_list + assert '3.0' in version_list + assert '3.1' in version_list diff --git a/tests/test_pandas_parameter_type.py b/tests/test_pandas_parameter_type.py index 590af24..7dcba80 100644 --- a/tests/test_pandas_parameter_type.py +++ b/tests/test_pandas_parameter_type.py @@ -7,6 +7,7 @@ import pandas as pd from pandas.testing import assert_frame_equal +from inference_schema.schema_util import get_supported_versions class TestPandasParameterType(object): @@ -25,6 +26,11 @@ def test_pandas_handling(self, decorated_pandas_func): result = decorated_pandas_func(**pandas_input) assert_frame_equal(result, state) + version_list = get_supported_versions(decorated_pandas_func) + assert '2.0' in version_list + assert '3.0' in version_list + assert '3.1' in version_list + def test_pandas_orient_handling(self, decorated_pandas_func_split_orient): pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]} state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state']) diff --git a/tests/test_spark_parameter_type.py b/tests/test_spark_parameter_type.py index 0fffd0f..9df7042 100644 --- a/tests/test_spark_parameter_type.py +++ b/tests/test_spark_parameter_type.py @@ -5,6 +5,7 @@ import pandas as pd from pyspark.sql.session import SparkSession +from inference_schema.schema_util import get_supported_versions class TestSparkParameterType(object): @@ -25,3 +26,8 @@ def test_spark_handling(self, decorated_spark_func): spark_input = {'param': [{'name': 'Sarah', 'state': 'WA'}]} result = decorated_spark_func(**spark_input) assert state.subtract(result).count() == result.subtract(state).count() == 0 + + version_list = get_supported_versions(decorated_spark_func) + assert '2.0' in version_list + assert '3.0' in version_list + assert '3.1' in version_list diff --git a/tests/test_standard_parameter_type.py b/tests/test_standard_parameter_type.py index 2487ce3..cb5e090 100644 --- a/tests/test_standard_parameter_type.py +++ b/tests/test_standard_parameter_type.py @@ -1,11 +1,13 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +from inference_schema.parameter_types.standard_py_parameter_type import StandardPythonParameterType +from inference_schema.schema_util import get_supported_versions class TestStandardPythonParameterType(object): - def test_standard_handling(self, decorated_standard_func): + def test_standard_handling_unique(self, decorated_standard_func): standard_input = {'name': ['Sarah'], 'state': ['WA']} state = {'state': ['WA']} result = decorated_standard_func(standard_input) @@ -15,6 +17,24 @@ def test_standard_handling(self, decorated_standard_func): result = decorated_standard_func(**standard_input) assert state == result + version_list = get_supported_versions(decorated_standard_func) + assert '2.0' in version_list + assert '3.0' in version_list + assert '3.1' in version_list + + def test_standard_handling_list(self, decorated_standard_func_multitype_list): + standard_input = ['foo', 1] + assert 1 == decorated_standard_func_multitype_list(standard_input) + + version_list = get_supported_versions(decorated_standard_func_multitype_list) + assert '2.0' not in version_list + assert '3.0' in version_list + assert '3.1' in version_list + + def test_supported_versions_string(self): + assert '2.0' in StandardPythonParameterType({'name': ['Sarah'], 'state': ['WA']}).supported_versions() + assert '2.0' not in StandardPythonParameterType(['foo', 1]).supported_versions() + def test_float_int_handling(self, decorated_float_func): float_input = 1.0 result = decorated_float_func(float_input)