Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions inference_schema/parameter_types/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ def get_swagger_for_list(python_data):
item_type = type(python_data[0])

for data in python_data:
if type(data) != item_type:
raise Exception('Error, OpenAPI 2.x does not support mixed type in array.')

if issubclass(item_type, AbstractParameterType):
nested_item_swagger = data.input_to_swagger()
else:
Expand Down Expand Up @@ -100,3 +97,26 @@ def get_swagger_for_nested_dict(python_data):

schema = {"type": "object", "required": required, "properties": nested_items, "example": examples}
return schema


def get_supported_versions_from_schema(schema):
    """
    Compute the OpenAPI (swagger) versions that can describe the given schema.

    OpenAPI 3.x is always supported; 2.0 is included only when the schema's
    example data contains no constructs that 2.x cannot express (such as
    mixed-type arrays).

    :param schema: a generated swagger schema dict containing an 'example' entry
    :type schema: dict
    :return: sorted list of supported swagger version strings
    :rtype: list
    """
    versions = ['3.0', '3.1']
    if _supports_swagger_2(schema['example']):
        versions.append('2.0')
    return sorted(versions)


def _supports_swagger_2(object):
Comment thread
wamartin-aml marked this conversation as resolved.
Outdated
if type(object) is list:
first_type = type(object[0])
for elt in object:
if type(elt) is not first_type:
return False
elif type(elt) is list:
if not _supports_swagger_2(elt):
return False
elif type(object) is dict:
for elt in object.values():
if not _supports_swagger_2(elt):
return False
return True
7 changes: 7 additions & 0 deletions inference_schema/parameter_types/abstract_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def input_to_swagger(self):
"""
pass

@abstractmethod
def supported_versions(self):
    """
    Abstract method to be overridden by concrete types. Used to return supported OpenApi swagger versions.

    :return: list of supported swagger version strings, e.g. ['2.0', '3.0', '3.1']
    :rtype: list
    """
    pass

@classmethod
def _date_item_to_string(cls, date_item):
return date_item.astype(dt.datetime).strftime("%Y-%m-%d")
Expand Down
4 changes: 4 additions & 0 deletions inference_schema/parameter_types/numpy_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .abstract_parameter_type import AbstractParameterType
from ._swagger_from_dtype import Dtype2Swagger
from ._constants import SWAGGER_FORMAT_CONSTANTS
from ._util import get_supported_versions_from_schema


class NumpyParameterType(AbstractParameterType):
Expand Down Expand Up @@ -36,6 +37,9 @@ def __init__(self, sample_input, enforce_column_type=True, enforce_shape=True):
self.enforce_column_type = enforce_column_type
self.enforce_shape = enforce_shape

def supported_versions(self):
    """
    Return the OpenApi swagger versions that can describe this parameter
    type's generated schema.

    :return: sorted list of supported swagger version strings
    :rtype: list
    """
    swagger_schema = self.input_to_swagger()
    return get_supported_versions_from_schema(swagger_schema)

def deserialize_input(self, input_data):
"""
Convert the provided array-like object into a numpy array. Will attempt to enforce column type and array shape
Expand Down
5 changes: 4 additions & 1 deletion inference_schema/parameter_types/pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import pandas as pd
from .abstract_parameter_type import AbstractParameterType
from ._util import get_swagger_for_list, get_swagger_for_nested_dict
from ._util import get_swagger_for_list, get_swagger_for_nested_dict, get_supported_versions_from_schema
from ._constants import SWAGGER_FORMAT_CONSTANTS


Expand Down Expand Up @@ -49,6 +49,9 @@ def __init__(self, sample_input, enforce_column_type=True, enforce_shape=True, a
"'values', or 'table')")
self.orient = orient

def supported_versions(self):
    """
    Return the OpenApi swagger versions that can describe this parameter
    type's generated schema.

    :return: sorted list of supported swagger version strings
    :rtype: list
    """
    swagger_schema = self.input_to_swagger()
    return get_supported_versions_from_schema(swagger_schema)

def deserialize_input(self, input_data):
"""
Convert the provided pandas-like object into a pandas dataframe. Will attempt to enforce column type and array
Expand Down
4 changes: 4 additions & 0 deletions inference_schema/parameter_types/spark_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyspark.sql.types import MapType
from pyspark.sql.types import UserDefinedType
from .abstract_parameter_type import AbstractParameterType
from ._util import get_supported_versions_from_schema


class SparkParameterType(AbstractParameterType):
Expand All @@ -41,6 +42,9 @@ def __init__(self, sample_input, apply_sample_schema=True):
super(SparkParameterType, self).__init__(sample_input)
self.apply_sample_schema = apply_sample_schema

def supported_versions(self):
    """
    Return the OpenApi swagger versions that can describe this parameter
    type's generated schema.

    :return: sorted list of supported swagger version strings
    :rtype: list
    """
    swagger_schema = self.input_to_swagger()
    return get_supported_versions_from_schema(swagger_schema)

def deserialize_input(self, input_data):
"""
Convert the provided spark-like object into a spark dataframe. Will attempt to enforce column type and array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dateutil import parser
from .abstract_parameter_type import AbstractParameterType
from ._constants import DATE_FORMAT, ERR_PYTHON_DATA_NOT_JSON_SERIALIZABLE
from ._util import handle_standard_types
from ._util import handle_standard_types, get_supported_versions_from_schema


class StandardPythonParameterType(AbstractParameterType):
Expand Down Expand Up @@ -40,6 +40,9 @@ def __init__(self, sample_input):
if issubclass(type(data), AbstractParameterType):
self.sample_data_type_list.append(data)

def supported_versions(self):
    """
    Return the OpenApi swagger versions that can describe this parameter
    type's generated schema.

    :return: sorted list of supported swagger version strings
    :rtype: list
    """
    swagger_schema = self.input_to_swagger()
    return get_supported_versions_from_schema(swagger_schema)

def deserialize_input(self, input_data):
"""
Convert the provided data into the expected Python object.
Expand Down
68 changes: 63 additions & 5 deletions inference_schema/schema_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import copy
from functools import partial

from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__
from .schema_util import _get_decorators, _get_function_full_qual_name, __functions_schema__, __versions__
from .parameter_types.abstract_parameter_type import AbstractParameterType
from ._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR

Expand Down Expand Up @@ -39,8 +39,9 @@ def input_schema(param_name, param_type, convert_to_provided_type=True):
'of the AbstractParameterType.')

swagger_schema = {param_name: param_type.input_to_swagger()}
supported_versions = param_type.supported_versions()

@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema)
@_schema_decorator(attr_name=INPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
def decorator_input(user_run, instance, args, kwargs):
if convert_to_provided_type:
args = list(args)
Expand Down Expand Up @@ -82,16 +83,17 @@ def output_schema(output_type):
'of the AbstractParameterType.')

swagger_schema = output_type.input_to_swagger()
supported_versions = output_type.supported_versions()

@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema)
@_schema_decorator(attr_name=OUTPUT_SCHEMA_ATTR, schema=swagger_schema, supported_versions=supported_versions)
def decorator_input(user_run, instance, args, kwargs):
return user_run(*args, **kwargs)

return decorator_input


# Heavily based on the wrapt.decorator implementation
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None, supported_versions=None):
"""
Decorator to generate decorators, preserving the metadata passed to the
decorator arguments, that is needed to be able to extact that information
Expand All @@ -107,6 +109,8 @@ def _schema_decorator(wrapper=None, enabled=None, attr_name=None, schema=None):
:type attr_name: str | None
:param schema:
:type schema: dict | None
:param supported_versions:
:type supported_versions: List | None
:return:
:rtype: function | FunctionWrapper
"""
Expand Down Expand Up @@ -134,6 +138,7 @@ def _capture(target_wrapped):
return _capture

_add_schema_to_global_schema_dictionary(attr_name, schema, args[0])
_add_versions_to_global_versions_dictionary(attr_name, supported_versions, args[0])
target_wrapped = args[0]

_enabled = enabled
Expand Down Expand Up @@ -165,7 +170,8 @@ def _capture(target_wrapped):
_schema_decorator,
enabled=enabled,
attr_name=attr_name,
schema=schema
schema=schema,
supported_versions=supported_versions
)


Expand Down Expand Up @@ -201,6 +207,35 @@ def _add_schema_to_global_schema_dictionary(attr_name, schema, user_func):
pass


def _add_versions_to_global_versions_dictionary(attr_name, versions, user_func):
    """
    function to add supported swagger versions for 'attr_name', to the function versions dict

    :param attr_name:
    :type attr_name: str
    :param versions:
    :type versions: List
    :param user_func:
    :type user_func: function | FunctionWrapper
    :return:
    :rtype:
    """

    # Nothing to record without both an attribute name and a version list.
    # NOTE: the previous code used 'pass' here, which fell through and could
    # register a None versions entry for a decorated function.
    if attr_name is None or versions is None:
        return

    decorators = _get_decorators(user_func)
    base_func_name = _get_function_full_qual_name(decorators[-1])

    if base_func_name not in __versions__:
        __versions__[base_func_name] = {}

    # Only input/output schema attributes carry version information.
    if attr_name in (INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR):
        _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr_name)


def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, schema):
"""
function to add a generated input schema, to the function schema dict
Expand Down Expand Up @@ -233,6 +268,29 @@ def _add_input_schema_to_global_schema_dictionary(base_func_name, arg_names, sch
__functions_schema__[base_func_name][INPUT_SCHEMA_ATTR]["properties"][k] = item_swagger


def _add_attr_versions_to_global_schema_dictionary(base_func_name, versions, attr):
    """
    function to add supported swagger versions to the version dict

    :param base_func_name: function full qualified name
    :type base_func_name: str
    :param versions:
    :type versions: list
    :param attr:
    :type attr: str
    :return:
    :rtype:
    """

    # Create the per-attribute entry on first use, then (re)record the
    # supported version list for this attribute.
    entry = __versions__[base_func_name].setdefault(attr, {"type": "object", "versions": {}})
    entry["versions"] = versions


def _add_output_schema_to_global_schema_dictionary(base_func_name, schema):
"""
function to add a generated output schema, to the function schema dict
Expand Down
19 changes: 19 additions & 0 deletions inference_schema/schema_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from inference_schema._constants import INPUT_SCHEMA_ATTR, OUTPUT_SCHEMA_ATTR

__functions_schema__ = {}
__versions__ = {}


def get_input_schema(func):
Expand Down Expand Up @@ -36,6 +37,24 @@ def get_output_schema(func):
return _get_schema_from_dictionary(OUTPUT_SCHEMA_ATTR, func)


def get_supported_versions(func):
    """
    Extract supported swagger versions from the decorated function.

    The result is the intersection of the versions supported by the input
    schema and those supported by the output schema, sorted ascending.

    :param func:
    :type func: function | FunctionWrapper
    :return:
    :rtype: list
    """
    decorators = _get_decorators(func)
    base_name = _get_function_full_qual_name(decorators[-1])
    func_versions = __versions__.get(base_name, {})

    inputs = func_versions.get(INPUT_SCHEMA_ATTR, {}).get('versions', [])
    outputs = func_versions.get(OUTPUT_SCHEMA_ATTR, {}).get('versions', [])
    return sorted(set(inputs) & set(outputs))


def get_schemas_dict():
"""
Retrieve a deepcopy of the dictionary that is used to track the provided function schemas
Expand Down
6 changes: 6 additions & 0 deletions tests/test_numpy_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------

import numpy as np
from inference_schema.schema_util import get_supported_versions


class TestNumpyParameterType(object):
Expand All @@ -21,3 +22,8 @@ def test_numpy_handling(self, decorated_numpy_func):
numpy_input = {"param": [{"name": "Sarah", "grades": [8.0, 7.0]}]}
result = decorated_numpy_func(**numpy_input)
assert np.array_equal(result, grades)

version_list = get_supported_versions(decorated_numpy_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list
6 changes: 6 additions & 0 deletions tests/test_pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from pandas.testing import assert_frame_equal
from inference_schema.schema_util import get_supported_versions


class TestPandasParameterType(object):
Expand All @@ -25,6 +26,11 @@ def test_pandas_handling(self, decorated_pandas_func):
result = decorated_pandas_func(**pandas_input)
assert_frame_equal(result, state)

version_list = get_supported_versions(decorated_pandas_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list

def test_pandas_orient_handling(self, decorated_pandas_func_split_orient):
pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]}
state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state'])
Expand Down
6 changes: 6 additions & 0 deletions tests/test_spark_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd

from pyspark.sql.session import SparkSession
from inference_schema.schema_util import get_supported_versions


class TestSparkParameterType(object):
Expand All @@ -25,3 +26,8 @@ def test_spark_handling(self, decorated_spark_func):
spark_input = {'param': [{'name': 'Sarah', 'state': 'WA'}]}
result = decorated_spark_func(**spark_input)
assert state.subtract(result).count() == result.subtract(state).count() == 0

version_list = get_supported_versions(decorated_spark_func)
assert '2.0' in version_list
assert '3.0' in version_list
assert '3.1' in version_list
Loading