Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,21 @@ def getter(
}
for key in set(data) - fields:
value = data[key]
if unknown == INCLUDE:
if isinstance(unknown, ma_fields.Field):

def getter(val, unknown_field=unknown, field_name=key):
return unknown_field.deserialize(val, field_name, data)

deserialized = self._call_and_store(
getter_func=getter,
data=value,
field_name=key,
error_store=error_store,
index=index,
)
if deserialized is not missing:
ret_d[key] = deserialized
elif unknown == INCLUDE:
ret_d[key] = value
elif unknown == RAISE:
error_store.store_error(
Expand Down
15 changes: 13 additions & 2 deletions src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,25 @@

import typing

if typing.TYPE_CHECKING:
from marshmallow.fields import Field

#: A type that can be either a sequence of strings or a set of strings
StrSequenceOrSet: typing.TypeAlias = typing.Sequence[str] | typing.AbstractSet[str]

#: Type for validator functions
Validator: typing.TypeAlias = typing.Callable[[typing.Any], typing.Any]

#: A valid option for the ``unknown`` schema option and argument
UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"]
#: A valid option for the ``unknown`` schema option and argument.
#: Can be a string constant (``"exclude"``, ``"include"``, ``"raise"``)
#: or a :class:`Field <marshmallow.fields.Field>` instance to deserialize unknown
#: field values through.
if typing.TYPE_CHECKING:
UnknownOption: typing.TypeAlias = (
typing.Literal["exclude", "include", "raise"] | Field
)
else:
UnknownOption: typing.TypeAlias = typing.Literal["exclude", "include", "raise"]


class SchemaValidator(typing.Protocol):
Expand Down
69 changes: 69 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,75 @@ class ErrorSchema(Schema):
assert "Invalid email" in errors["email"]


class TestUnknownFieldOption:
"""Tests for passing a Field instance to ``unknown``."""

def test_unknown_field_deserializes_values(self):
class MySchema(Schema):
name = fields.String()

schema = MySchema(unknown=fields.Int())
result = schema.load({"name": "Joe", "age": "42"})
assert result == {"name": "Joe", "age": 42}

def test_unknown_field_validation_error(self):
class MySchema(Schema):
name = fields.String()

schema = MySchema(unknown=fields.Int())
with pytest.raises(ValidationError) as excinfo:
schema.load({"name": "Joe", "age": "not_a_number"})
assert "age" in excinfo.value.messages

def test_unknown_field_in_meta(self):
class MySchema(Schema):
class Meta:
unknown = fields.String()

name = fields.String()

result = MySchema().load({"name": "Joe", "extra": "hello"})
assert result == {"name": "Joe", "extra": "hello"}

def test_unknown_field_with_many(self):
class MySchema(Schema):
name = fields.String()

schema = MySchema(unknown=fields.Int())
result = schema.load(
[{"name": "Joe", "age": "42"}, {"name": "Jane", "score": "99"}],
many=True,
)
assert result == [{"name": "Joe", "age": 42}, {"name": "Jane", "score": 99}]

def test_unknown_field_in_load_kwarg(self):
class MySchema(Schema):
name = fields.String()

schema = MySchema()
result = schema.load({"name": "Joe", "extra": "42"}, unknown=fields.Int())
assert result == {"name": "Joe", "extra": 42}

def test_unknown_field_nested(self):
class ChildSchema(Schema):
num = fields.Int()

class ParentSchema(Schema):
child = fields.Nested(ChildSchema, unknown=fields.String())

data = {"child": {"num": 1, "extra": "hello"}}
result = ParentSchema().load(data)
assert result == {"child": {"num": 1, "extra": "hello"}}

def test_unknown_field_excludes_nothing(self):
class MySchema(Schema):
name = fields.String()

schema = MySchema(unknown=fields.Field())
result = schema.load({"name": "Joe", "extra": "value", "more": 123})
assert result == {"name": "Joe", "extra": "value", "more": 123}


def test_custom_unknown_error_message():
custom_message = "custom error message."

Expand Down