diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index c4a26041a..466d4e919 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -3489,6 +3489,116 @@ def arguments_schema( ) +class ArgumentsV3Parameter(TypedDict, total=False): + name: Required[str] + schema: Required[CoreSchema] + mode: Literal[ + 'positional_only', + 'positional_or_keyword', + 'keyword_only', + 'var_args', + 'var_kwargs_uniform', + 'var_kwargs_unpacked_typed_dict', + ] # default positional_or_keyword + alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] + + +def arguments_v3_parameter( + name: str, + schema: CoreSchema, + *, + mode: Literal[ + 'positional_only', + 'positional_or_keyword', + 'keyword_only', + 'var_args', + 'var_kwargs_uniform', + 'var_kwargs_unpacked_typed_dict', + ] + | None = None, + alias: str | list[str | int] | list[list[str | int]] | None = None, +) -> ArgumentsV3Parameter: + """ + Returns a schema that matches an argument parameter, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + param = core_schema.arguments_v3_parameter( + name='a', schema=core_schema.str_schema(), mode='positional_only' + ) + schema = core_schema.arguments_v3_schema([param]) + v = SchemaValidator(schema) + assert v.validate_python({'a': 'hello'}) == (('hello',), {}) + ``` + + Args: + name: The name to use for the argument parameter + schema: The schema to use for the argument parameter + mode: The mode to use for the argument parameter + alias: The alias to use for the argument parameter + """ + return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias) + + +class ArgumentsV3Schema(TypedDict, total=False): + type: Required[Literal['arguments-v3']] + arguments_schema: Required[list[ArgumentsV3Parameter]] + validate_by_name: bool + validate_by_alias: bool + var_args_schema: CoreSchema + var_kwargs_mode: VarKwargsMode + var_kwargs_schema: CoreSchema + ref: str + metadata: dict[str, Any] + serialization: SerSchema + + +def arguments_v3_schema( + arguments: list[ArgumentsV3Parameter], + *, + validate_by_name: bool | None = None, + validate_by_alias: bool | None = None, + ref: str | None = None, + metadata: dict[str, Any] | None = None, + serialization: SerSchema | None = None, +) -> ArgumentsV3Schema: + """ + Returns a schema that matches an arguments schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + param_a = core_schema.arguments_v3_parameter( + name='a', schema=core_schema.str_schema(), mode='positional_only' + ) + param_b = core_schema.arguments_v3_parameter( + name='kwargs', schema=core_schema.bool_schema(), mode='var_kwargs_uniform' + ) + schema = core_schema.arguments_v3_schema([param_a, param_b]) + v = SchemaValidator(schema) + assert v.validate_python({'a': 'hello', 'kwargs': {'extra': True}}) == (('hello',), {'extra': True}) + ``` + + Args: + arguments: The arguments to use for the arguments schema. + validate_by_name: Whether to populate by the parameter names, defaults to `False`. + validate_by_alias: Whether to populate by the parameter aliases, defaults to `True`. + ref: optional unique identifier of the schema, used to reference the schema in other places. + metadata: Any other information you want to include with the schema, not used by pydantic-core. + serialization: Custom serialization schema. + """ + return _dict_not_none( + type='arguments-v3', + arguments_schema=arguments, + validate_by_name=validate_by_name, + validate_by_alias=validate_by_alias, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + class CallSchema(TypedDict, total=False): type: Required[Literal['call']] arguments_schema: Required[CoreSchema] @@ -3916,6 +4026,7 @@ def definition_reference_schema( DataclassArgsSchema, DataclassSchema, ArgumentsSchema, + ArgumentsV3Schema, CallSchema, CustomErrorSchema, JsonSchema, diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index aa9c6eb16..119c862ac 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -81,6 +81,8 @@ pub trait Input<'py>: fmt::Debug { fn validate_args(&self) -> ValResult>; + fn validate_args_v3(&self) -> ValResult>; + fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult>; fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch>; @@ -265,6 +267,7 @@ pub trait ValidatedList<'py> { pub trait ValidatedTuple<'py> { type Item: BorrowInput<'py>; fn len(&self) -> Option; + fn try_for_each(self, f: impl FnMut(PyResult) -> ValResult<()>) -> ValResult<()>; fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult; } @@ -313,6 +316,9 @@ impl<'py> ValidatedTuple<'py> for Never { fn len(&self) -> Option { unreachable!() } + fn try_for_each(self, _f: impl FnMut(PyResult) -> ValResult<()>) -> ValResult<()> { + unreachable!() + } fn iterate(self, _consumer: impl ConsumeIterator, Output = R>) -> ValResult { unreachable!() } diff --git a/src/input/input_json.rs b/src/input/input_json.rs index 139c71a25..3fd906e16 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -84,6 +84,11 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> { } } + #[cfg_attr(has_coverage_attribute, coverage(off))] + fn validate_args_v3(&self) -> ValResult> { + Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) + } + fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult> { match self { JsonValue::Object(object) => Ok(JsonArgs::new(None, Some(object))), @@ -375,6 +380,11 @@ impl<'py> Input<'py> for str { Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) } + #[cfg_attr(has_coverage_attribute, coverage(off))] + fn validate_args_v3(&self) -> ValResult { + Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) + } + #[cfg_attr(has_coverage_attribute, coverage(off))] fn validate_dataclass_args(&self, class_name: &str) -> ValResult { let class_name = class_name.to_string(); @@ -571,6 +581,12 @@ impl<'a, 'data> ValidatedTuple<'_> for &'a JsonArray<'data> { fn len(&self) -> Option { Some(SmallVec::len(self)) } + fn try_for_each(self, mut f: impl FnMut(PyResult) -> ValResult<()>) -> ValResult<()> { + for item in self.iter() { + f(Ok(item))?; + } + Ok(()) + } fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { Ok(consumer.consume_iterator(self.iter().map(Ok))) } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index ea6eab054..e82cbaed7 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -117,6 +117,16 @@ impl<'py> Input<'py> for Bound<'py, PyAny> { } } + fn validate_args_v3(&self) -> ValResult> { + if let Ok(args_kwargs) = self.extract::() { + let args = args_kwargs.args.into_bound(self.py()); + let kwargs = args_kwargs.kwargs.map(|d| d.into_bound(self.py())); + Ok(PyArgs::new(Some(args), kwargs)) + } else { + Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) + } + } + fn validate_dataclass_args<'a>(&'a self, class_name: &str) -> ValResult> { if let Ok(dict) = self.downcast::() { Ok(PyArgs::new(None, Some(dict.clone()))) @@ -915,7 +925,15 @@ impl<'py> PySequenceIterable<'_, 'py> { PySequenceIterable::Iterator(iter) => iter.len().ok(), } } - + fn generic_try_for_each(self, f: impl FnMut(PyResult>) -> ValResult<()>) -> ValResult<()> { + match self { + PySequenceIterable::List(iter) => iter.iter().map(Ok).try_for_each(f), + PySequenceIterable::Tuple(iter) => iter.iter().map(Ok).try_for_each(f), + PySequenceIterable::Set(iter) => iter.iter().map(Ok).try_for_each(f), + PySequenceIterable::FrozenSet(iter) => iter.iter().map(Ok).try_for_each(f), + PySequenceIterable::Iterator(mut iter) => iter.try_for_each(f), + } + } fn generic_iterate( self, consumer: impl ConsumeIterator>, Output = R>, @@ -951,6 +969,9 @@ impl<'py> ValidatedTuple<'py> for PySequenceIterable<'_, 'py> { fn len(&self) -> Option { self.generic_len() } + fn try_for_each(self, f: impl FnMut(PyResult) -> ValResult<()>) -> ValResult<()> { + self.generic_try_for_each(f) + } fn iterate(self, consumer: impl ConsumeIterator, Output = R>) -> ValResult { self.generic_iterate(consumer) } diff --git a/src/input/input_string.rs b/src/input/input_string.rs index a50b3cff2..0ab4ad014 100644 --- a/src/input/input_string.rs +++ b/src/input/input_string.rs @@ -89,6 +89,11 @@ impl<'py> Input<'py> for StringMapping<'py> { Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) } + fn validate_args_v3(&self) -> ValResult> { + // do we want to support this? + Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)) + } + fn validate_dataclass_args<'a>(&'a self, _dataclass_name: &str) -> ValResult> { match self { StringMapping::String(_) => Err(ValError::new(ErrorTypeDefaults::ArgumentsType, self)), diff --git a/src/validators/arguments_v3.rs b/src/validators/arguments_v3.rs new file mode 100644 index 000000000..55a9aff6a --- /dev/null +++ b/src/validators/arguments_v3.rs @@ -0,0 +1,633 @@ +use std::str::FromStr; + +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyList, PyString, PyTuple}; + +use ahash::AHashSet; +use pyo3::IntoPyObjectExt; + +use crate::build_tools::py_schema_err; +use crate::build_tools::{schema_or_config_same, ExtraBehavior}; +use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult}; +use crate::input::{ + Arguments, BorrowInput, Input, KeywordArgs, PositionalArgs, ValidatedDict, ValidatedTuple, ValidationMatch, +}; +use crate::lookup_key::LookupKeyCollection; +use crate::tools::SchemaDict; + +use super::validation_state::ValidationState; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator}; + +#[derive(Debug, PartialEq)] +enum ParameterMode { + PositionalOnly, + PositionalOrKeyword, + VarArgs, + KeywordOnly, + VarKwargsUniform, + VarKwargsUnpackedTypedDict, +} + +impl FromStr for ParameterMode { + type Err = PyErr; + + fn from_str(s: &str) -> Result { + match s { + "positional_only" => Ok(Self::PositionalOnly), + "positional_or_keyword" => Ok(Self::PositionalOrKeyword), + "var_args" => Ok(Self::VarArgs), + "keyword_only" => Ok(Self::KeywordOnly), + "var_kwargs_uniform" => Ok(Self::VarKwargsUniform), + "var_kwargs_unpacked_typed_dict" => Ok(Self::VarKwargsUnpackedTypedDict), + s => py_schema_err!("Invalid var_kwargs mode: `{}`", s), + } + } +} + +#[derive(Debug)] +struct Parameter { + name: String, + mode: ParameterMode, + lookup_key_collection: LookupKeyCollection, + validator: CombinedValidator, +} + +impl Parameter { + fn is_variadic(&self) -> bool { + matches!( + self.mode, + ParameterMode::VarArgs | ParameterMode::VarKwargsUniform | ParameterMode::VarKwargsUnpackedTypedDict + ) + } +} + +#[derive(Debug)] +pub struct ArgumentsV3Validator { + parameters: Vec, + positional_params_count: usize, + loc_by_alias: bool, + extra: ExtraBehavior, + validate_by_alias: Option, + validate_by_name: Option, +} + +impl BuildValidator for ArgumentsV3Validator { + const EXPECTED_TYPE: &'static str = "arguments-v3"; + + fn build( + schema: &Bound<'_, PyDict>, + config: Option<&Bound<'_, PyDict>>, + definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + + let arguments_schema: Bound<'_, PyList> = schema.get_as_req(intern!(py, "arguments_schema"))?; + let mut parameters: Vec = Vec::with_capacity(arguments_schema.len()); + + let mut had_default_arg = false; + let mut had_keyword_only = false; + + for arg in arguments_schema.iter() { + let arg = arg.downcast::()?; + + let py_name: Bound = arg.get_as_req(intern!(py, "name"))?; + let name = py_name.to_string(); + let py_mode = arg.get_as::>(intern!(py, "mode"))?; + let py_mode = py_mode + .as_ref() + .map(|py_str| py_str.to_str()) + .transpose()? + .unwrap_or("positional_or_keyword"); + + let mode = ParameterMode::from_str(py_mode)?; + + // let positional = mode == "positional_only" || mode == "positional_or_keyword"; + // if positional { + // positional_params_count = arg_index + 1; + // } + + if mode == ParameterMode::KeywordOnly { + had_keyword_only = true; + } + + let schema = arg.get_as_req(intern!(py, "schema"))?; + + let validator = match build_validator(&schema, config, definitions) { + Ok(v) => v, + Err(err) => return py_schema_err!("Parameter '{}':\n {}", name, err), + }; + + let has_default = match validator { + CombinedValidator::WithDefault(ref v) => { + if v.omit_on_error() { + return py_schema_err!("Parameter '{}': omit_on_error cannot be used with arguments", name); + } + v.has_default() + } + _ => false, + }; + + if had_default_arg && !has_default && !had_keyword_only { + return py_schema_err!("Non-default argument '{}' follows default argument", name); + } else if has_default { + had_default_arg = true; + } + + let validation_alias = arg.get_item(intern!(py, "alias"))?; + let lookup_key_collection = LookupKeyCollection::new(py, validation_alias, name.as_str())?; + + parameters.push(Parameter { + name, + mode, + lookup_key_collection, + validator, + }); + } + + let positional_params_count = parameters + .iter() + .filter(|p| { + matches!( + p.mode, + ParameterMode::PositionalOnly | ParameterMode::PositionalOrKeyword + ) + }) + .count(); + + Ok(Self { + parameters, + positional_params_count, + loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), + extra: ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Forbid)?, + validate_by_alias: schema_or_config_same(schema, config, intern!(py, "validate_by_alias"))?, + validate_by_name: schema_or_config_same(schema, config, intern!(py, "validate_by_name"))?, + } + .into()) + } +} + +impl_py_gc_traverse!(Parameter { validator }); + +impl_py_gc_traverse!(ArgumentsV3Validator { parameters }); + +impl ArgumentsV3Validator { + /// Validate the arguments from a mapping: + /// ```py + /// def func(a: int, /, *, b: str, **kwargs: int) -> None: + /// ... + /// + /// valid_mapping = {'a': 1, 'b': 'test', 'kwargs': {'c': 1, 'd': 2}} + /// ``` + fn validate_from_mapping<'py>( + &self, + py: Python<'py>, + original_input: &(impl Input<'py> + ?Sized), + mapping: impl ValidatedDict<'py>, + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let mut output_args: Vec = Vec::with_capacity(self.positional_params_count); + let output_kwargs = PyDict::new(py); + let mut errors: Vec = Vec::new(); + + let validate_by_alias = state.validate_by_alias_or(self.validate_by_alias); + let validate_by_name = state.validate_by_name_or(self.validate_by_name); + + for parameter in &self.parameters { + let lookup_key = parameter + .lookup_key_collection + .select(validate_by_alias, validate_by_name)?; + // A value is present in the mapping: + if let Some((lookup_path, dict_value)) = mapping.get_item(lookup_key)? { + match parameter.mode { + ParameterMode::PositionalOnly | ParameterMode::PositionalOrKeyword => { + match parameter.validator.validate(py, dict_value.borrow_input(), state) { + Ok(value) => output_args.push(value), + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors.into_iter().map(|err| { + lookup_path.apply_error_loc(err, self.loc_by_alias, ¶meter.name) + }), + ); + } + Err(err) => return Err(err), + } + } + ParameterMode::VarArgs => match dict_value.borrow_input().validate_tuple(false) { + Ok(tuple) => { + tuple.unpack(state).try_for_each(|v| { + match parameter.validator.validate(py, v.unwrap().borrow_input(), state) { + Ok(tuple_value) => { + output_args.push(tuple_value); + Ok(()) + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| { + lookup_path.apply_error_loc(err, self.loc_by_alias, ¶meter.name) + })); + Ok(()) + } + Err(err) => Err(err), + } + })?; + } + Err(_) => { + let val_error = ValLineError::new(ErrorTypeDefaults::TupleType, dict_value.borrow_input()); + errors.push(lookup_path.apply_error_loc(val_error, self.loc_by_alias, ¶meter.name)); + } + }, + ParameterMode::KeywordOnly => { + match parameter.validator.validate(py, dict_value.borrow_input(), state) { + Ok(value) => { + output_kwargs.set_item(PyString::new(py, parameter.name.as_str()).unbind(), value)?; + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors.into_iter().map(|err| { + lookup_path.apply_error_loc(err, self.loc_by_alias, ¶meter.name) + }), + ); + } + Err(err) => return Err(err), + } + } + ParameterMode::VarKwargsUniform => match dict_value.borrow_input().as_kwargs(py) { + // We will validate that keys are strings, and values match the validator: + Some(value) => { + for (dict_key, dict_value) in value { + // Validate keys are strings: + match dict_key.validate_str(true, false).map(ValidationMatch::into_inner) { + Ok(_) => (), + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push( + err.with_outer_location(dict_key.clone()) + .with_outer_location(¶meter.name) + .with_type(ErrorTypeDefaults::InvalidKey), + ); + } + continue; + } + Err(err) => return Err(err), + } + // Validate values: + match parameter.validator.validate(py, dict_value.borrow_input(), state) { + Ok(value) => output_kwargs.set_item(dict_key, value)?, + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| { + lookup_path.apply_error_loc( + err.with_outer_location(dict_key.clone()), + self.loc_by_alias, + ¶meter.name, + ) + })); + } + Err(err) => return Err(err), + } + } + } + None => { + let val_error = ValLineError::new(ErrorTypeDefaults::DictType, dict_value); + errors.push(lookup_path.apply_error_loc(val_error, self.loc_by_alias, ¶meter.name)); + } + }, + ParameterMode::VarKwargsUnpackedTypedDict => { + let kwargs_dict = dict_value + .borrow_input() + .as_kwargs(py) + .unwrap_or_else(|| PyDict::new(py)); + match parameter.validator.validate(py, kwargs_dict.as_any(), state) { + Ok(value) => { + output_kwargs.update(value.downcast_bound::(py).unwrap().as_mapping())?; + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors); + } + Err(err) => return Err(err), + } + } + } + // No value is present in the mapping, fallback to the default value (and error if no default): + } else { + match parameter.mode { + ParameterMode::PositionalOnly | ParameterMode::PositionalOrKeyword | ParameterMode::KeywordOnly => { + if let Some(value) = + parameter + .validator + .default_value(py, Some(parameter.name.as_str()), state)? + { + if parameter.mode == ParameterMode::PositionalOnly { + output_args.push(value); + } else { + output_kwargs.set_item(PyString::new(py, parameter.name.as_str()).unbind(), value)?; + } + } else { + let error_type = match parameter.mode { + ParameterMode::PositionalOnly => ErrorTypeDefaults::MissingPositionalOnlyArgument, + ParameterMode::PositionalOrKeyword => ErrorTypeDefaults::MissingArgument, + ParameterMode::KeywordOnly => ErrorTypeDefaults::MissingKeywordOnlyArgument, + _ => unreachable!(), + }; + + errors.push(lookup_key.error( + error_type, + original_input, + self.loc_by_alias, + ¶meter.name, + )); + } + } + // Variadic args/kwargs can be empty by definition: + _ => (), + } + } + } + + if !errors.is_empty() { + Err(ValError::LineErrors(errors)) + } else { + Ok((PyTuple::new(py, output_args)?, output_kwargs).into_py_any(py)?) + } + } + + /// Validate the arguments from an [`ArgsKwargs`][crate::argument_markers::ArgsKwargs] instance: + /// ```py + /// def func(a: int, /, *, b: str, **kwargs: int) -> None: + /// ... + /// + /// valid_argskwargs = ArgsKwargs((1,), {'b': 'test', 'c': 1, 'd': 2}) + /// ``` + fn validate_from_argskwargs<'py>( + &self, + py: Python<'py>, + original_input: &(impl Input<'py> + ?Sized), + args_kwargs: impl Arguments<'py>, + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + let mut output_args: Vec = Vec::with_capacity(self.positional_params_count); + let output_kwargs = PyDict::new(py); + let mut errors: Vec = Vec::new(); + let mut used_kwargs: AHashSet<&str> = AHashSet::with_capacity(self.parameters.len()); + + let validate_by_alias = state.validate_by_alias_or(self.validate_by_alias); + let validate_by_name = state.validate_by_name_or(self.validate_by_name); + + // go through non variadic arguments, getting the value from args or kwargs and validating it + for (index, parameter) in self.parameters.iter().filter(|p| !p.is_variadic()).enumerate() { + let lookup_key = parameter + .lookup_key_collection + .select(validate_by_alias, validate_by_name)?; + + let mut pos_value = None; + if let Some(args) = args_kwargs.args() { + if matches!( + parameter.mode, + ParameterMode::PositionalOnly | ParameterMode::PositionalOrKeyword + ) { + pos_value = args.get_item(index); + } + } + + let mut kw_value = None; + if let Some(kwargs) = args_kwargs.kwargs() { + if matches!( + parameter.mode, + ParameterMode::PositionalOrKeyword | ParameterMode::KeywordOnly + ) { + if let Some((lookup_path, value)) = kwargs.get_item(lookup_key)? { + used_kwargs.insert(lookup_path.first_key()); + kw_value = Some((lookup_path, value)); + } + } + } + + match (pos_value, kw_value) { + (Some(_), Some((_, kw_value))) => { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::MultipleArgumentValues, + kw_value.borrow_input(), + parameter.name.clone(), + )); + } + (Some(pos_value), None) => match parameter.validator.validate(py, pos_value.borrow_input(), state) { + Ok(value) => output_args.push(value), + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); + } + Err(err) => return Err(err), + }, + (None, Some((lookup_path, kw_value))) => { + match parameter.validator.validate(py, kw_value.borrow_input(), state) { + Ok(value) => { + output_kwargs.set_item(PyString::new(py, parameter.name.as_str()).unbind(), value)?; + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend( + line_errors + .into_iter() + .map(|err| lookup_path.apply_error_loc(err, self.loc_by_alias, ¶meter.name)), + ); + } + Err(err) => return Err(err), + } + } + (None, None) => { + if let Some(value) = parameter + .validator + .default_value(py, Some(parameter.name.as_str()), state)? + { + if matches!( + parameter.mode, + ParameterMode::PositionalOnly | ParameterMode::PositionalOrKeyword + ) { + output_kwargs.set_item(PyString::new(py, parameter.name.as_str()).unbind(), value)?; + } else { + output_args.push(value); + } + } else { + // Required and no default, error: + match parameter.mode { + ParameterMode::PositionalOnly => { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::MissingPositionalOnlyArgument, + original_input, + index, + )); + } + ParameterMode::PositionalOrKeyword => { + errors.push(lookup_key.error( + ErrorTypeDefaults::MissingArgument, + original_input, + self.loc_by_alias, + ¶meter.name, + )); + } + ParameterMode::KeywordOnly => { + errors.push(lookup_key.error( + ErrorTypeDefaults::MissingKeywordOnlyArgument, + original_input, + self.loc_by_alias, + ¶meter.name, + )); + } + _ => unreachable!(), + } + } + } + } + } + + // if there are args check any where index > positional_params_count since they won't have been checked yet + if let Some(args) = args_kwargs.args() { + let len = args.len(); + if len > self.positional_params_count { + if let Some(var_args_param) = self.parameters.iter().find(|p| p.mode == ParameterMode::VarArgs) { + for (index, item) in args.iter().enumerate().skip(self.positional_params_count) { + match var_args_param.validator.validate(py, item.borrow_input(), state) { + Ok(value) => output_args.push(value), + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); + } + Err(err) => return Err(err), + } + } + } else { + for (index, item) in args.iter().enumerate().skip(self.positional_params_count) { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedPositionalArgument, + item, + index, + )); + } + } + } + } + + let remaining_kwargs = PyDict::new(py); + + // if there are kwargs check any that haven't been processed yet + if let Some(kwargs) = args_kwargs.kwargs() { + if kwargs.len() > used_kwargs.len() { + for result in kwargs.iter() { + let (raw_key, value) = result?; + let either_str = match raw_key + .borrow_input() + .validate_str(true, false) + .map(ValidationMatch::into_inner) + { + Ok(k) => k, + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push( + err.with_outer_location(raw_key.clone()) + .with_type(ErrorTypeDefaults::InvalidKey), + ); + } + continue; + } + Err(err) => return Err(err), + }; + if !used_kwargs.contains(either_str.as_cow()?.as_ref()) { + let maybe_var_kwargs_parameter = self.parameters.iter().find(|p| { + matches!( + p.mode, + ParameterMode::VarKwargsUniform | ParameterMode::VarKwargsUnpackedTypedDict + ) + }); + + match maybe_var_kwargs_parameter { + None => { + if self.extra == ExtraBehavior::Forbid { + errors.push(ValLineError::new_with_loc( + ErrorTypeDefaults::UnexpectedKeywordArgument, + value, + raw_key.clone(), + )); + } + } + Some(var_kwargs_parameter) => { + match var_kwargs_parameter.mode { + ParameterMode::VarKwargsUniform => { + match var_kwargs_parameter.validator.validate(py, value.borrow_input(), state) { + Ok(value) => { + output_kwargs + .set_item(either_str.as_py_string(py, state.cache_str()), value)?; + } + Err(ValError::LineErrors(line_errors)) => { + for err in line_errors { + errors.push(err.with_outer_location(raw_key.clone())); + } + } + Err(err) => return Err(err), + } + } + ParameterMode::VarKwargsUnpackedTypedDict => { + // Save to the remaining kwargs, we will validate as a single dict: + remaining_kwargs.set_item( + either_str.as_py_string(py, state.cache_str()), + value.borrow_input().to_object(py)?, + )?; + } + _ => unreachable!(), + } + } + } + } + } + } + } + + if !remaining_kwargs.is_empty() { + // In this case, the unpacked typeddict var kwargs parameter is guaranteed to exist: + let var_kwargs_parameter = self + .parameters + .iter() + .find(|p| p.mode == ParameterMode::VarKwargsUnpackedTypedDict) + .unwrap(); + match var_kwargs_parameter + .validator + .validate(py, remaining_kwargs.as_any(), state) + { + Ok(value) => { + output_kwargs.update(value.downcast_bound::(py).unwrap().as_mapping())?; + } + Err(ValError::LineErrors(line_errors)) => { + errors.extend(line_errors); + } + Err(err) => return Err(err), + } + } + + if !errors.is_empty() { + Err(ValError::LineErrors(errors)) + } else { + Ok((PyTuple::new(py, output_args)?, output_kwargs).into_py_any(py)?) + } + } +} + +impl Validator for ArgumentsV3Validator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + // this validator does not yet support partial validation, disable it to avoid incorrect results + state.allow_partial = false.into(); + + let args_dict = input.validate_dict(false); + + // Validation from a dictionary, mapping parameter names to the values: + if let Ok(dict) = args_dict { + self.validate_from_mapping(py, input, dict, state) + } else { + let args = input.validate_args_v3()?; + self.validate_from_argskwargs(py, input, args, state) + } + } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 4e40bf080..f105e1854 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -21,6 +21,7 @@ pub(crate) use config::ValBytesMode; mod any; mod arguments; +mod arguments_v3; mod bool; mod bytes; mod call; @@ -636,6 +637,7 @@ pub fn build_validator( callable::CallableValidator, // arguments arguments::ArgumentsValidator, + arguments_v3::ArgumentsV3Validator, // default value with_default::WithDefaultValidator, // chain validators @@ -802,6 +804,7 @@ pub enum CombinedValidator { Callable(callable::CallableValidator), // arguments Arguments(arguments::ArgumentsValidator), + ArgumentsV3(arguments_v3::ArgumentsV3Validator), // default value WithDefault(with_default::WithDefaultValidator), // chain validators diff --git a/tests/conftest.py b/tests/conftest.py index 56d94a0c1..226ace018 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ import hypothesis import pytest -from pydantic_core import ArgsKwargs, SchemaValidator, ValidationError, validate_core_schema +from pydantic_core import ArgsKwargs, CoreSchema, SchemaValidator, ValidationError, validate_core_schema from pydantic_core.core_schema import CoreConfig __all__ = 'Err', 'PyAndJson', 'plain_repr', 'infinite_generator' @@ -52,7 +52,11 @@ def json_default(obj): class PyAndJsonValidator: def __init__( - self, schema, config: CoreConfig | None = None, *, validator_type: Literal['json', 'python'] | None = None + self, + schema: CoreSchema, + config: CoreConfig | None = None, + *, + validator_type: Literal['json', 'python'] | None = None, ): self.validator = SchemaValidator(validate_core_schema(schema), config) self.validator_type = validator_type diff --git a/tests/validators/test_arguments_v3.py b/tests/validators/test_arguments_v3.py new file mode 100644 index 000000000..d5b099e46 --- /dev/null +++ b/tests/validators/test_arguments_v3.py @@ -0,0 +1,82 @@ +import pytest + +from pydantic_core import ArgsKwargs, ValidationError +from pydantic_core import core_schema as cs + +from ..conftest import PyAndJson + + +@pytest.mark.parametrize( + ['input_value', 'expected'], + ( + [ArgsKwargs((1, True), {}), ((1, True), {})], + [ArgsKwargs((1,), {}), ((1,), {})], + [{'a': 1, 'b': True}, ((1, True), {})], + [{'a': 1}, ((1,), {})], + ), +) +def test_positional_only(py_and_json: PyAndJson, input_value, expected) -> None: + v = py_and_json( + cs.arguments_v3_schema( + [ + cs.arguments_v3_parameter(name='a', schema=cs.int_schema(), mode='positional_only'), + cs.arguments_v3_parameter( + name='b', schema=cs.with_default_schema(cs.bool_schema()), mode='positional_only' + ), + ] + ) + ) + + assert v.validate_test(input_value) == expected + + +def test_positional_only_validation_error(py_and_json: PyAndJson) -> None: + v = py_and_json( + cs.arguments_v3_schema( + [ + cs.arguments_v3_parameter(name='a', schema=cs.int_schema(), mode='positional_only'), + ] + ) + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_test(ArgsKwargs(('not_an_int',), {})) + + error = exc_info.value.errors()[0] + + assert error['type'] == 'int_parsing' + assert error['loc'] == (0,) + + with pytest.raises(ValidationError) as exc_info: + v.validate_test({'a': 'not_an_int'}) + + error = exc_info.value.errors()[0] + + assert error['type'] == 'int_parsing' + assert error['loc'] == ('a',) + + +def test_positional_only_error_required(py_and_json: PyAndJson) -> None: + v = py_and_json( + cs.arguments_v3_schema( + [ + cs.arguments_v3_parameter(name='a', schema=cs.int_schema(), mode='positional_only'), + ] + ) + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_test(ArgsKwargs(tuple(), {})) + + error = exc_info.value.errors()[0] + + assert error['type'] == 'missing_positional_only_argument' + assert error['loc'] == (0,) + + with pytest.raises(ValidationError) as exc_info: + v.validate_test({}) + + error = exc_info.value.errors()[0] + + assert error['type'] == 'missing_positional_only_argument' + assert error['loc'] == ('a',)