From 6a5830b550278031d6bee55b8541ebb027767ebe Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Thu, 13 Feb 2025 16:11:34 +1100 Subject: [PATCH] c --- .github/workflows/test-python.yml | 4 +- crates/polars-core/src/serde/df.rs | 2 +- crates/polars-core/src/serde/series.rs | 2 +- .../src/cloud/credential_provider.rs | 147 ++++--- crates/polars-io/src/cloud/options.rs | 35 +- .../parquet/metadata/column_chunk_metadata.rs | 2 +- crates/polars-plan/src/dsl/expr_dyn_fn.rs | 4 +- crates/polars-python/src/catalog/unity.rs | 2 +- crates/polars-python/src/dataframe/io.rs | 6 +- crates/polars-python/src/lazyframe/general.rs | 16 +- crates/polars-utils/src/pl_serialize.rs | 2 +- crates/polars-utils/src/python_function.rs | 73 ++-- py-polars/polars/_utils/logging.py | 10 + py-polars/polars/catalog/unity/client.py | 33 +- py-polars/polars/dataframe/frame.py | 49 ++- py-polars/polars/io/cloud/__init__.py | 2 +- .../io/cloud/credential_provider/__init__.py | 17 + .../io/cloud/credential_provider/_builder.py | 382 ++++++++++++++++++ .../_providers.py} | 245 ++--------- py-polars/polars/io/csv/functions.py | 12 +- py-polars/polars/io/delta.py | 22 +- py-polars/polars/io/ipc/functions.py | 9 +- py-polars/polars/io/ndjson.py | 16 +- py-polars/polars/io/parquet/functions.py | 11 +- py-polars/polars/lazyframe/frame.py | 36 +- .../unit/io/cloud/test_credential_provider.py | 57 ++- 26 files changed, 803 insertions(+), 393 deletions(-) create mode 100644 py-polars/polars/_utils/logging.py create mode 100644 py-polars/polars/io/cloud/credential_provider/__init__.py create mode 100644 py-polars/polars/io/cloud/credential_provider/_builder.py rename py-polars/polars/io/cloud/{credential_provider.py => credential_provider/_providers.py} (63%) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 97fd4173add4..41fcc1cd3d2b 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -23,7 +23,7 @@ concurrency: cancel-in-progress: 
true env: - RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down + RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down RUST_BACKTRACE: 1 PYTHONUTF8: 1 @@ -39,7 +39,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ['3.9', '3.12', '3.13'] + python-version: ['3.9', '3.12', '3.13', '3.12.8', '3.13.1'] include: - os: windows-latest python-version: '3.13' diff --git a/crates/polars-core/src/serde/df.rs b/crates/polars-core/src/serde/df.rs index e0fe9053bf19..8dbc3636abbf 100644 --- a/crates/polars-core/src/serde/df.rs +++ b/crates/polars-core/src/serde/df.rs @@ -166,7 +166,7 @@ impl<'de> Deserialize<'de> for DataFrame { where D: Deserializer<'de>, { - deserialize_map_bytes(deserializer, &mut |b| { + deserialize_map_bytes(deserializer, |b| { let v = &mut b.as_ref(); Self::deserialize_from_reader(v) })? diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs index 0df7c19fde2a..db9080a95737 100644 --- a/crates/polars-core/src/serde/series.rs +++ b/crates/polars-core/src/serde/series.rs @@ -57,7 +57,7 @@ impl<'de> Deserialize<'de> for Series { where D: Deserializer<'de>, { - deserialize_map_bytes(deserializer, &mut |b| { + deserialize_map_bytes(deserializer, |b| { let v = &mut b.as_ref(); Self::deserialize_from_reader(v) })? 
diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs index 5cd366fb1a88..6c6c9344a965 100644 --- a/crates/polars-io/src/cloud/credential_provider.rs +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -15,8 +15,6 @@ pub use object_store::gcp::GcpCredential; use polars_core::config; use polars_error::{polars_bail, PolarsResult}; #[cfg(feature = "python")] -use polars_utils::python_function::PythonFunction; -#[cfg(feature = "python")] use python_impl::PythonCredentialProvider; #[derive(Clone, Debug, PartialEq, Hash, Eq)] @@ -43,23 +41,36 @@ impl PlCredentialProvider { Self::Function(CredentialProviderFunction(Arc::new(func))) } + /// Intended to be called with an internal `CredentialProviderBuilder`. #[cfg(feature = "python")] - pub fn from_python_func(func: PythonFunction) -> Self { - Self::Python(python_impl::PythonCredentialProvider(Arc::new(func))) - } + pub fn from_python_builder(func: pyo3::PyObject) -> Self { + use polars_utils::python_function::PythonObject; - #[cfg(feature = "python")] - pub fn from_python_func_object(func: pyo3::PyObject) -> Self { - Self::Python(python_impl::PythonCredentialProvider(Arc::new( - PythonFunction(func), - ))) + Self::Python(python_impl::PythonCredentialProvider::from_builder( + Arc::new(PythonObject(func)), + )) } pub(super) fn func_addr(&self) -> usize { match self { Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize, #[cfg(feature = "python")] - Self::Python(PythonCredentialProvider(v)) => Arc::as_ptr(v) as *const () as usize, + Self::Python(PythonCredentialProvider { + // We know this is only used for hashing, it is safe to ignore `is_builder`, since we + // don't expect that the same py_object can be both a builder and provider. 
+ py_object, + is_builder: _, + }) => Arc::as_ptr(py_object) as *const () as usize, + } + } + + /// Python passes a `CredentialProviderBuilder`, this calls the builder to + /// build the final credential provider. + pub(crate) fn try_into_initialized(self) -> PolarsResult> { + match self { + Self::Function(_) => Ok(Some(self)), + #[cfg(feature = "python")] + Self::Python(v) => Ok(v.try_into_initialized()?.map(Self::Python)), } } } @@ -452,8 +463,8 @@ mod python_impl { use std::hash::Hash; use std::sync::Arc; - use polars_error::PolarsError; - use polars_utils::python_function::PythonFunction; + use polars_error::{to_compute_err, PolarsError, PolarsResult}; + use polars_utils::python_function::PythonObject; use pyo3::exceptions::PyValueError; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods}; @@ -462,11 +473,63 @@ mod python_impl { use super::IntoCredentialProvider; #[derive(Clone, Debug)] - pub struct PythonCredentialProvider(pub(super) Arc); + #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] + pub struct PythonCredentialProvider { + pub(super) py_object: Arc, + /// Indicates `py_object` is a `CredentialProviderBuilder`. 
+ pub(super) is_builder: bool, + } + + impl PythonCredentialProvider { + pub(crate) fn from_builder(py_object: Arc) -> Self { + if cfg!(debug_assertions) { + Python::with_gil(|py| { + let cls_name = py_object + .getattr(py, "__class__") + .unwrap() + .getattr(py, "__name__") + .unwrap() + .extract::(py) + .unwrap(); + + assert_eq!(&cls_name, "CredentialProviderBuilder"); + }); + } + + Self { + py_object, + is_builder: true, + } + } + + pub(crate) fn from_provider(py_object: Arc) -> Self { + Self { + py_object, + is_builder: false, + } + } + + /// Performs initialization if necessary + pub(crate) fn try_into_initialized(self) -> PolarsResult> { + if self.is_builder { + let opt_initialized_py_object = Python::with_gil(|py| { + let build_fn = self.py_object.getattr(py, "build_credential_provider")?; + + let v = build_fn.call0(py)?; + let v = (!v.is_none(py)).then_some(v); - impl From for PythonCredentialProvider { - fn from(value: PythonFunction) -> Self { - Self(Arc::new(value)) + pyo3::PyResult::Ok(v) + }) + .map_err(to_compute_err)?; + + Ok(opt_initialized_py_object + .map(PythonObject) + .map(Arc::new) + .map(Self::from_provider)) + } else { + // Note: We don't expect to hit here. + Ok(Some(self)) + } } } @@ -479,8 +542,12 @@ mod python_impl { CredentialProviderFunction, ObjectStoreCredential, }; + assert!(!self.is_builder); // should not be a builder at this point. + + let func = self.py_object; + CredentialProviderFunction(Arc::new(move || { - let func = self.0.clone(); + let func = func.clone(); Box::pin(async move { let mut credentials = object_store::aws::AwsCredential { key_id: String::new(), @@ -554,8 +621,12 @@ mod python_impl { CredentialProviderFunction, ObjectStoreCredential, }; + assert!(!self.is_builder); // should not be a builder at this point. 
+ + let func = self.py_object; + CredentialProviderFunction(Arc::new(move || { - let func = self.0.clone(); + let func = func.clone(); Box::pin(async move { let mut credentials = None; @@ -621,8 +692,11 @@ mod python_impl { CredentialProviderFunction, ObjectStoreCredential, }; + assert!(!self.is_builder); // should not be a builder at this point. + + let func = self.py_object; CredentialProviderFunction(Arc::new(move || { - let func = self.0.clone(); + let func = func.clone(); Box::pin(async move { let mut credentials = object_store::gcp::GcpCredential { bearer: String::new(), @@ -666,11 +740,14 @@ mod python_impl { } } + // Note: We don't consider `is_builder` for hash/eq - we don't expect the same Arc + // to be referenced as both true and false from the `is_builder` field. + impl Eq for PythonCredentialProvider {} impl PartialEq for PythonCredentialProvider { fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.0, &other.0) + Arc::ptr_eq(&self.py_object, &other.py_object) } } @@ -680,33 +757,7 @@ mod python_impl { // * Inner is an `Arc` // * Visibility is limited to super // * No code in `mod python_impl` or `super` mutates the Arc inner. 
- state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) - } - } - - #[cfg(feature = "serde")] - mod _serde_impl { - use polars_utils::python_function::PySerializeWrap; - - use super::PythonCredentialProvider; - - impl serde::Serialize for PythonCredentialProvider { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - PySerializeWrap(self.0.as_ref()).serialize(serializer) - } - } - - impl<'a> serde::Deserialize<'a> for PythonCredentialProvider { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'a>, - { - PySerializeWrap::::deserialize(deserializer) - .map(|x| x.0.into()) - } + state.write_usize(Arc::as_ptr(&self.py_object) as *const () as usize) } } } diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index 2141bd9cab2f..399133a75269 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -29,8 +29,6 @@ use regex::Regex; #[cfg(feature = "http")] use reqwest::header::HeaderMap; #[cfg(feature = "serde")] -use serde::Deserializer; -#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; #[cfg(feature = "cloud")] use url::Url; @@ -80,19 +78,11 @@ pub struct CloudOptions { pub file_cache_ttl: u64, pub(crate) config: Option, #[cfg(feature = "cloud")] - #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_or_default"))] + /// Note: In most cases you will want to access this via [`CloudOptions::initialized_credential_provider`] + /// rather than directly. 
pub(crate) credential_provider: Option, } -#[cfg(all(feature = "serde", feature = "cloud"))] -fn deserialize_or_default<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - type T = Option; - T::deserialize(deserializer).or_else(|_| Ok(Default::default())) -} - impl Default for CloudOptions { fn default() -> Self { Self::default_static_ref().clone() @@ -392,7 +382,7 @@ impl CloudOptions { let builder = builder.with_retry(get_retry_config(self.max_retries)); - let builder = if let Some(v) = self.credential_provider.clone() { + let builder = if let Some(v) = self.initialized_credential_provider()? { builder.with_credentials(v.into_aws_provider()) } else { builder @@ -438,7 +428,7 @@ impl CloudOptions { .with_url(url) .with_retry(get_retry_config(self.max_retries)); - let builder = if let Some(v) = self.credential_provider.clone() { + let builder = if let Some(v) = self.initialized_credential_provider()? { if verbose { eprintln!( "[CloudOptions::build_azure]: Using credential provider {:?}", @@ -470,7 +460,9 @@ impl CloudOptions { pub fn build_gcp(&self, url: &str) -> PolarsResult { use super::credential_provider::IntoCredentialProvider; - let builder = if self.credential_provider.is_none() { + let credential_provider = self.initialized_credential_provider()?; + + let builder = if credential_provider.is_none() { GoogleCloudStorageBuilder::from_env() } else { GoogleCloudStorageBuilder::new() @@ -491,7 +483,7 @@ impl CloudOptions { .with_url(url) .with_retry(get_retry_config(self.max_retries)); - let builder = if let Some(v) = self.credential_provider.clone() { + let builder = if let Some(v) = credential_provider.clone() { builder.with_credentials(v.into_gcp_provider()) } else { builder @@ -629,6 +621,17 @@ impl CloudOptions { }, } } + + /// Python passes a credential provider builder that needs to be called to get the actual credential + /// provider. 
+ #[cfg(feature = "cloud")] + fn initialized_credential_provider(&self) -> PolarsResult> { + if let Some(v) = self.credential_provider.clone() { + v.try_into_initialized() + } else { + Ok(None) + } + } } #[cfg(feature = "cloud")] diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs index 71db1a74601a..9fc241d5c51c 100644 --- a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs +++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs @@ -65,7 +65,7 @@ where { use polars_utils::pl_serialize::deserialize_map_bytes; - deserialize_map_bytes(deserializer, &mut |b| { + deserialize_map_bytes(deserializer, |b| { let mut b = b.as_ref(); let mut protocol = TCompactInputProtocol::new(&mut b, usize::MAX); ColumnChunk::read_from_in_protocol(&mut protocol).map_err(D::Error::custom) diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs index 7dc3b0cc20c9..9d72e2673251 100644 --- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs +++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs @@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq> { use serde::de::Error; #[cfg(feature = "python")] { - deserialize_map_bytes(deserializer, &mut |buf| { + deserialize_map_bytes(deserializer, |buf| { if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) { let udf = crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; @@ -407,7 +407,7 @@ impl<'a> Deserialize<'a> for GetOutput { use serde::de::Error; #[cfg(feature = "python")] { - deserialize_map_bytes(deserializer, &mut |buf| { + deserialize_map_bytes(deserializer, |buf| { if buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) { let get_output = self::python_dsl::PythonGetOutput::try_deserialize(&buf) .map_err(|e| D::Error::custom(format!("{e}")))?; diff --git 
a/crates/polars-python/src/catalog/unity.rs b/crates/polars-python/src/catalog/unity.rs index 41928f239f0b..e26abaedbaf8 100644 --- a/crates/polars-python/src/catalog/unity.rs +++ b/crates/polars-python/src/catalog/unity.rs @@ -262,7 +262,7 @@ impl PyCatalogClient { parse_cloud_options(storage_location, cloud_options.unwrap_or_default())? .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ); Ok( diff --git a/crates/polars-python/src/dataframe/io.rs b/crates/polars-python/src/dataframe/io.rs index a7c15bd180f6..9cdf1eeb413d 100644 --- a/crates/polars-python/src/dataframe/io.rs +++ b/crates/polars-python/src/dataframe/io.rs @@ -362,7 +362,7 @@ impl PyDataFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ), ) } else { @@ -424,7 +424,7 @@ impl PyDataFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ), ) } else { @@ -517,7 +517,7 @@ impl PyDataFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ), ) } else { diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 06c303e7b437..75fd54706c3f 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -86,7 +86,7 @@ impl PyLazyFrame { cloud_options = cloud_options .with_max_retries(retries) .with_credential_provider( - 
credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ); if let Some(file_cache_ttl) = file_cache_ttl { @@ -206,7 +206,7 @@ impl PyLazyFrame { cloud_options = cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ); r = r.with_cloud_options(Some(cloud_options)); } @@ -341,7 +341,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ), ); } @@ -417,7 +417,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(PlCredentialProvider::from_python_func_object), + credential_provider.map(PlCredentialProvider::from_python_builder), ), ); } @@ -715,7 +715,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_builder), ), ) }; @@ -748,7 +748,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_builder), ), ) }; @@ -819,7 +819,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + 
credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_builder), ), ) }; @@ -854,7 +854,7 @@ impl PyLazyFrame { cloud_options .with_max_retries(retries) .with_credential_provider( - credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_func_object), + credential_provider.map(polars::prelude::cloud::credential_provider::PlCredentialProvider::from_python_builder), ), ) }; diff --git a/crates/polars-utils/src/pl_serialize.rs b/crates/polars-utils/src/pl_serialize.rs index 04079943a2be..e6947563509d 100644 --- a/crates/polars-utils/src/pl_serialize.rs +++ b/crates/polars-utils/src/pl_serialize.rs @@ -105,7 +105,7 @@ where /// This is essentially boilerplate for visiting bytes without copying where possible. pub fn deserialize_map_bytes<'de, D, O>( deserializer: D, - func: &mut (dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O), + mut func: impl for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O, ) -> Result where D: serde::de::Deserializer<'de>, diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs index b9798b83133d..064bcd13b728 100644 --- a/crates/polars-utils/src/python_function.rs +++ b/crates/polars-utils/src/python_function.rs @@ -10,24 +10,42 @@ pub use serde_wrap::{ use crate::pl_serialize::deserialize_map_bytes; +/// Wrapper around PyObject from pyo3 with additional trait impls. #[derive(Debug)] -pub struct PythonFunction(pub PyObject); +pub struct PythonObject(pub PyObject); +// Note: We have this because the struct itself used to be called `PythonFunction`, so it's +// referred to as such from a lot of places. 
+pub type PythonFunction = PythonObject; -impl Clone for PythonFunction { +impl std::ops::Deref for PythonObject { + type Target = PyObject; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for PythonObject { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Clone for PythonObject { fn clone(&self) -> Self { Python::with_gil(|py| Self(self.0.clone_ref(py))) } } -impl From for PythonFunction { +impl From for PythonObject { fn from(value: PyObject) -> Self { Self(value) } } -impl Eq for PythonFunction {} +impl Eq for PythonObject {} -impl PartialEq for PythonFunction { +impl PartialEq for PythonObject { fn eq(&self, other: &Self) -> bool { Python::with_gil(|py| { let eq = self.0.getattr(py, "__eq__").unwrap(); @@ -41,7 +59,7 @@ impl PartialEq for PythonFunction { } #[cfg(feature = "serde")] -impl serde::Serialize for PythonFunction { +impl serde::Serialize for PythonObject { fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, @@ -56,20 +74,20 @@ impl serde::Serialize for PythonFunction { } #[cfg(feature = "serde")] -impl<'a> serde::Deserialize<'a> for PythonFunction { +impl<'a> serde::Deserialize<'a> for PythonObject { fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'a>, { use serde::de::Error; - deserialize_map_bytes(deserializer, &mut |bytes| { + deserialize_map_bytes(deserializer, |bytes| { Self::try_deserialize_bytes(&bytes).map_err(|e| D::Error::custom(e.to_string())) })? 
} } #[cfg(feature = "serde")] -impl TrySerializeToBytes for PythonFunction { +impl TrySerializeToBytes for PythonObject { fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult> { serde_wrap::serialize_pyobject_with_cloudpickle_fallback(&self.0) } @@ -85,6 +103,7 @@ mod serde_wrap { use polars_error::PolarsResult; use super::*; + use crate::config; use crate::pl_serialize::deserialize_map_bytes; pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes(); @@ -126,7 +145,7 @@ mod serde_wrap { { use serde::de::Error; - deserialize_map_bytes(deserializer, &mut |bytes| { + deserialize_map_bytes(deserializer, |bytes| { let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else { return Err(D::Error::custom( "unexpected EOF when reading serialized pyobject version", @@ -182,23 +201,30 @@ mod serde_wrap { let dumped = pickle.call1((py_object.clone_ref(py),)); - let (dumped, used_cloudpickle) = if let Ok(v) = dumped { - (v, false) - } else { - let cloudpickle = PyModule::import(py, "cloudpickle") - .map_err(from_pyerr)? - .getattr("dumps") - .unwrap(); - let dumped = cloudpickle - .call1((py_object.clone_ref(py),)) - .map_err(from_pyerr)?; - (dumped, true) + let (dumped, used_cloudpickle) = match dumped { + Ok(v) => (v, false), + Err(e) => { + if config::verbose() { + eprintln!( + "serialize_pyobject_with_cloudpickle_fallback(): \ + retrying with cloudpickle due to error: {:?}", + e + ); + } + + let cloudpickle = PyModule::import(py, "cloudpickle")? 
+ .getattr("dumps") + .unwrap(); + let dumped = cloudpickle.call1((py_object.clone_ref(py),))?; + (dumped, true) + }, }; - let py_bytes = dumped.extract::().map_err(from_pyerr)?; + let py_bytes = dumped.extract::()?; Ok([&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()].concat()) }) + .map_err(from_pyerr) } pub fn deserialize_pyobject_bytes_maybe_cloudpickle From>( @@ -223,9 +249,10 @@ mod serde_wrap { .getattr("loads") .unwrap(); let arg = (PyBytes::new(py, bytes),); - let pyany_bound = pickle.call1(arg).map_err(from_pyerr)?; + let pyany_bound = pickle.call1(arg)?; Ok(PyObject::from(pyany_bound).into()) }) + .map_err(from_pyerr) } } diff --git a/py-polars/polars/_utils/logging.py b/py-polars/polars/_utils/logging.py new file mode 100644 index 000000000000..38d858b7842d --- /dev/null +++ b/py-polars/polars/_utils/logging.py @@ -0,0 +1,10 @@ +import os +import sys +from functools import partial + + +def verbose() -> bool: + return os.getenv("POLARS_VERBOSE") == "1" + + +eprint = partial(print, file=sys.stderr) diff --git a/py-polars/polars/catalog/unity/client.py b/py-polars/polars/catalog/unity/client.py index 1e3ba7085b11..818f29950919 100644 --- a/py-polars/polars/catalog/unity/client.py +++ b/py-polars/polars/catalog/unity/client.py @@ -28,6 +28,7 @@ CredentialProviderFunction, CredentialProviderFunctionReturn, ) + from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder from polars.lazyframe import LazyFrame with contextlib.suppress(ImportError): @@ -238,7 +239,7 @@ def scan_table( table_info, "scan table" ) - credential_provider, storage_options = self._init_credentials( + credential_provider, storage_options = self._init_credentials( # type: ignore[assignment] credential_provider, storage_options, table_info, @@ -367,7 +368,7 @@ def write_table( table_info, "scan table" ) - credential_provider, storage_options = self._init_credentials( + credential_provider, storage_options = self._init_credentials( # type: ignore[assignment] 
credential_provider, storage_options, table_info, @@ -600,15 +601,27 @@ def delete_table( def _init_credentials( self, - credential_provider: (CredentialProviderFunction | Literal["auto"] | None), + credential_provider: CredentialProviderFunction | Literal["auto"] | None, storage_options: dict[str, Any] | None, table_info: TableInfo, *, write: bool, caller_name: str, - ) -> tuple[CredentialProviderFunction | None, dict[str, Any] | None]: + ) -> tuple[ + CredentialProviderBuilder | None, + dict[str, Any] | None, + ]: + from polars.io.cloud.credential_provider._builder import ( + CredentialProviderBuilder, + ) + if credential_provider != "auto": - return credential_provider, storage_options + if credential_provider: + return CredentialProviderBuilder.from_initialized_provider( + credential_provider + ), storage_options + else: + return None, storage_options verbose = os.getenv("POLARS_VERBOSE") == "1" @@ -645,15 +658,19 @@ def _init_credentials( ) print(msg, file=sys.stderr) - return catalog_credential_provider, storage_options + return CredentialProviderBuilder.from_initialized_provider( + catalog_credential_provider + ), storage_options # This should generally not happen, but if using the temporary # credentials API fails for whatever reason, we fallback to our built-in # credential provider resolution. 
- from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - return _maybe_init_credential_provider( + return _init_credential_provider_builder( "auto", table_info.storage_location, storage_options, caller_name ), storage_options diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index f0ab632a9e92..4c8e978cbbc0 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -2951,11 +2951,14 @@ def write_csv_to_string() -> str: elif isinstance(file, (str, os.PathLike)): file = normalize_filepath(file) - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, file, storage_options, "write_csv" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -2979,7 +2982,7 @@ def write_csv_to_string() -> str: null_value, quote_style, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) @@ -3709,15 +3712,18 @@ def write_ipc( if compression is None: compression = "uncompressed" - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = ( + credential_provider_builder = ( None if return_bytes - else _maybe_init_credential_provider( + else _init_credential_provider_builder( credential_provider, file, storage_options, "write_ipc" ) ) + del credential_provider if storage_options: storage_options = 
list(storage_options.items()) # type: ignore[assignment] @@ -3730,7 +3736,7 @@ def write_ipc( compression, compat_level, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) return file if return_bytes else None # type: ignore[return-value] @@ -3993,11 +3999,14 @@ def write_parquet( return - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, file, storage_options, "write_parquet" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -4039,7 +4048,7 @@ def write_parquet( partition_by=partition_by, partition_chunk_size_bytes=partition_chunk_size_bytes, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) @@ -4498,33 +4507,39 @@ def write_delta( else: data = self.to_arrow() - from polars.io.cloud.credential_provider import ( + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) + from polars.io.cloud.credential_provider._providers import ( _get_credentials_from_provider_expiry_aware, - _maybe_init_credential_provider, ) if not isinstance(target, DeltaTable): - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, target, storage_options, "write_delta" ) elif credential_provider is not None and credential_provider != "auto": msg = "cannot use credential_provider when passing a DeltaTable object" raise ValueError(msg) else: - credential_provider = None + credential_provider_builder = None + + del credential_provider 
credential_provider_creds = {} - if credential_provider is not None: + if credential_provider_builder and ( + provider := credential_provider_builder.build_credential_provider() + ): credential_provider_creds = _get_credentials_from_provider_expiry_aware( - credential_provider + provider ) # We aren't calling into polars-native write functions so we just update # the storage_options here. storage_options = ( {**(storage_options or {}), **credential_provider_creds} - if storage_options is not None or credential_provider is not None + if storage_options is not None or credential_provider_builder is not None else None ) diff --git a/py-polars/polars/io/cloud/__init__.py b/py-polars/polars/io/cloud/__init__.py index 0dfa3717184e..7a5858fcdb0f 100644 --- a/py-polars/polars/io/cloud/__init__.py +++ b/py-polars/polars/io/cloud/__init__.py @@ -1,4 +1,4 @@ -from polars.io.cloud.credential_provider import ( +from polars.io.cloud.credential_provider._providers import ( CredentialProvider, CredentialProviderAWS, CredentialProviderAzure, diff --git a/py-polars/polars/io/cloud/credential_provider/__init__.py b/py-polars/polars/io/cloud/credential_provider/__init__.py new file mode 100644 index 000000000000..7a5858fcdb0f --- /dev/null +++ b/py-polars/polars/io/cloud/credential_provider/__init__.py @@ -0,0 +1,17 @@ +from polars.io.cloud.credential_provider._providers import ( + CredentialProvider, + CredentialProviderAWS, + CredentialProviderAzure, + CredentialProviderFunction, + CredentialProviderFunctionReturn, + CredentialProviderGCP, +) + +__all__ = [ + "CredentialProvider", + "CredentialProviderAWS", + "CredentialProviderAzure", + "CredentialProviderFunction", + "CredentialProviderFunctionReturn", + "CredentialProviderGCP", +] diff --git a/py-polars/polars/io/cloud/credential_provider/_builder.py b/py-polars/polars/io/cloud/credential_provider/_builder.py new file mode 100644 index 000000000000..1ab507a5e6e0 --- /dev/null +++ 
b/py-polars/polars/io/cloud/credential_provider/_builder.py @@ -0,0 +1,382 @@ +from __future__ import annotations + +import abc +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Literal + +import polars._utils.logging +from polars._utils.logging import eprint, verbose +from polars._utils.unstable import issue_unstable_warning +from polars.io.cloud.credential_provider._providers import ( + CredentialProviderAWS, + CredentialProviderAzure, + CredentialProviderGCP, +) + +if TYPE_CHECKING: + from polars.io.cloud.credential_provider._providers import ( + CredentialProvider, + CredentialProviderFunction, + ) + +# https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html +OBJECT_STORE_CLIENT_OPTIONS: frozenset[str] = frozenset( + [ + "allow_http", + "allow_invalid_certificates", + "connect_timeout", + "default_content_type", + "http1_only", + "http2_only", + "http2_keep_alive_interval", + "http2_keep_alive_timeout", + "http2_keep_alive_while_idle", + "http2_max_frame_size", + "pool_idle_timeout", + "pool_max_idle_per_host", + "proxy_url", + "proxy_ca_certificate", + "proxy_excludes", + "timeout", + "user_agent", + ] +) + + +# Note: The rust-side expects this exact class name. +class CredentialProviderBuilder: + """ + Builds credential providers. + + This allows for credential provider resolution to be deferred, so that it + takes place when a query is executed, rather than during the construction of + the query. + + For example, a query plan may be constructed on a local machine and then sent + to a remote machine for execution. It is possible that only the remote machine + has sufficient permissions to access the required cloud resources. In this + case, credential provider resolution must be deferred to take place on the + remote machine (where the query is executed). 
+ + If it instead occurs eagerly on the local machine, it may capture and use + local credentials that lack the required permissions to access the resources, + leading to a query error. + """ + + def __init__( + self, + credential_provider_init: CredentialProviderBuilderImpl, + ) -> None: + """ + Initialize configuration for building a credential provider. + + Parameters + ---------- + credential_provider_init + Initializer function that returns a credential provider. + """ + self.credential_provider_init = credential_provider_init + + # Note: The rust-side expects this exact function name. + def build_credential_provider( + self, + ) -> CredentialProvider | CredentialProviderFunction | None: + """Instantiate a credential provider from configuration.""" + verbose = polars._utils.logging.verbose() + + if verbose: + eprint( + "[CredentialProviderBuilder]: Begin initialize " + f"{self.credential_provider_init!r}" + ) + + v = self.credential_provider_init() + + if verbose: + if v is not None: + eprint( + f"[CredentialProviderBuilder]: Initialized {v!r} " + f"from {self.credential_provider_init!r}" + ) + else: + eprint( + f"[CredentialProviderBuilder]: No provider initialized " + f"from {self.credential_provider_init!r}" + ) + + return v + + @classmethod + def from_initialized_provider( + cls, credential_provider: CredentialProviderFunction + ) -> CredentialProviderBuilder: + """Initialize with an already constructed provider.""" + return cls(InitializedCredentialProvider(credential_provider)) + + def __getstate__(self) -> Any: + state = self.credential_provider_init + + if verbose(): + eprint(f"[CredentialProviderBuilder]: __getstate__(): {state = !r} ") + + return state + + def __setstate__(self, state: Any) -> None: + verbose = polars._utils.logging.verbose() + + if verbose: + eprint(f"[CredentialProviderBuilder]: __setstate__(): begin: {state = !r}") + + self.credential_provider_init = state + + if verbose: + eprint( + f"[CredentialProviderBuilder]: __setstate__(): 
finish: self = {self!r}" + ) + + def __repr__(self) -> str: + return f"CredentialProviderBuilder({self.credential_provider_init!r})" + + +class CredentialProviderBuilderImpl(abc.ABC): + @abc.abstractmethod + def __call__(self) -> CredentialProviderFunction | None: + pass + + @property + @abc.abstractmethod + def provider_class_name(self) -> str: + """The class name of the provider that is built by this builder.""" + + def __repr__(self) -> str: + provider_class_name = self.provider_class_name + provider_class_name = ( + "None" if provider_class_name == "NoneType" else provider_class_name + ) + builder_name = type(self).__name__ + + return f"{provider_class_name} @ {builder_name}" + + +# Wraps an already initialized credential provider into the builder interface. +# Used for e.g. user-provided credential providers. +class InitializedCredentialProvider(CredentialProviderBuilderImpl): + """Wraps an already initialized credential provider.""" + + def __init__(self, credential_provider: CredentialProviderFunction | None) -> None: + self.credential_provider = credential_provider + + def __call__(self) -> CredentialProviderFunction | None: + return self.credential_provider + + @property + def provider_class_name(self) -> str: + return type(self.credential_provider).__name__ + + +# Represents an automatic initialization configuration. This is created for +# credential_provider="auto". +class AutoInit(CredentialProviderBuilderImpl): + def __init__(self, cls: Any, **kw: Any) -> None: + self.cls = cls + self.kw = kw + + def __call__(self) -> Any: + # This is used for credential_provider="auto", which allows for + # ImportErrors. + try: + return self.cls(**self.kw) + except ImportError as e: + if verbose(): + eprint(f"failed to auto-initialize {self.provider_class_name}: {e!r}") + + return None + + @property + def provider_class_name(self) -> str: + return self.cls.__name__ + + +# AWS auto-init needs its own class for a bit of extra logic. 
+class AutoInitAWS(CredentialProviderBuilderImpl): + def __init__( + self, + initializer: Callable[[], CredentialProviderAWS], + ) -> None: + self.initializer = initializer + self.profile_name = initializer.keywords["profile_name"] # type: ignore[attr-defined] + + def __call__(self) -> CredentialProviderAWS | None: + try: + provider = self.initializer() + provider() # call it to potentially catch EmptyCredentialError + + except (ImportError, CredentialProviderAWS.EmptyCredentialError) as e: + # Check it is ImportError, EmptyCredentialError could be because the + # profile was loaded but did not contain any credentials. + if isinstance(e, ImportError) and self.profile_name: + # Hard error as we are unable to load the requested profile + # without CredentialProviderAWS (the rust-side does not load + # aws_profile). + msg = f"cannot load requested aws_profile '{self.profile_name}': {e!r}" + raise polars.exceptions.ComputeError(msg) from e + + if verbose(): + eprint(f"failed to auto-initialize {self.provider_class_name}: {e!r}") + + else: + return provider + + return None + + @property + def provider_class_name(self) -> str: + return "CredentialProviderAWS" + + +def _init_credential_provider_builder( + credential_provider: CredentialProviderFunction + | CredentialProviderBuilder + | Literal["auto"] + | None, + source: Any, + storage_options: dict[str, Any] | None, + caller_name: str, +) -> CredentialProviderBuilder | None: + def f() -> CredentialProviderBuilder | None: + # Note: The behavior of this function should depend only on the function + # parameters. Any environment-specific behavior should take place inside + # instantiated credential providers. 
+ + from polars.io.cloud._utils import ( + _first_scan_path, + _get_path_scheme, + _is_aws_cloud, + _is_azure_cloud, + _is_gcp_cloud, + ) + + if credential_provider is None: + return None + + if isinstance(credential_provider, CredentialProviderBuilder): + # This happens when the catalog client auto-inits and passes it to + # scan/write_delta, which calls us again. + return credential_provider + + if credential_provider != "auto": + msg = f"The `credential_provider` parameter of `{caller_name}` is considered unstable." + issue_unstable_warning(msg) + + return CredentialProviderBuilder.from_initialized_provider( + credential_provider + ) + + if (path := _first_scan_path(source)) is None: + return None + + if (scheme := _get_path_scheme(path)) is None: + return None + + if _is_azure_cloud(scheme): + tenant_id = None + storage_account = None + + if storage_options is not None: + for k, v in storage_options.items(): + k = k.lower() + + # https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html + if k in { + "azure_storage_tenant_id", + "azure_storage_authority_id", + "azure_tenant_id", + "azure_authority_id", + "tenant_id", + "authority_id", + }: + tenant_id = v + elif k in {"azure_storage_account_name", "account_name"}: + storage_account = v + elif k in {"azure_use_azure_cli", "use_azure_cli"}: + continue + elif k in OBJECT_STORE_CLIENT_OPTIONS: + continue + else: + # We assume some sort of access key was given, so we + # just dispatch to the rust side. 
+ return None + + storage_account = ( + # Prefer the one embedded in the path + CredentialProviderAzure._extract_adls_uri_storage_account(str(path)) + or storage_account + ) + + return CredentialProviderBuilder( + AutoInit( + CredentialProviderAzure, + tenant_id=tenant_id, + _storage_account=storage_account, + ) + ) + + elif _is_aws_cloud(scheme): + region = None + profile = None + default_region = None + unhandled_key = None + + if storage_options is not None: + for k, v in storage_options.items(): + k = k.lower() + + # https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html + if k in {"aws_region", "region"}: + region = v + elif k in {"aws_default_region", "default_region"}: + default_region = v + elif k in {"aws_profile", "profile"}: + profile = v + elif k in OBJECT_STORE_CLIENT_OPTIONS: + continue + else: + # We assume some sort of access key was given, so we + # just dispatch to the rust side. + unhandled_key = k + + if unhandled_key is not None: + if profile is not None: + msg = ( + "unsupported: cannot combine aws_profile with " + f"{unhandled_key} in storage_options" + ) + raise ValueError(msg) + + return None + + return CredentialProviderBuilder( + AutoInitAWS( + partial( + CredentialProviderAWS, + profile_name=profile, + region_name=region or default_region, + ) + ) + ) + + elif storage_options is not None and any( + key.lower() not in OBJECT_STORE_CLIENT_OPTIONS for key in storage_options + ): + return None + elif _is_gcp_cloud(scheme): + return CredentialProviderBuilder(AutoInit(CredentialProviderGCP)) + + return None + + credential_provider_init = f() + + if verbose(): + eprint(f"_init_credential_provider_builder(): {credential_provider_init = !r}") + + return credential_provider_init diff --git a/py-polars/polars/io/cloud/credential_provider.py b/py-polars/polars/io/cloud/credential_provider/_providers.py similarity index 63% rename from py-polars/polars/io/cloud/credential_provider.py rename to 
py-polars/polars/io/cloud/credential_provider/_providers.py index 18c17a4b42c3..02466e6f2f9c 100644 --- a/py-polars/polars/io/cloud/credential_provider.py +++ b/py-polars/polars/io/cloud/credential_provider/_providers.py @@ -8,9 +8,10 @@ import sys import zoneinfo from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypedDict, Union -from polars._utils.various import issue_warning +import polars._utils.logging +from polars._utils.logging import eprint, verbose if TYPE_CHECKING: if sys.version_info >= (3, 10): @@ -28,29 +29,6 @@ Callable[[], CredentialProviderFunctionReturn], "CredentialProvider" ] -# https://docs.rs/object_store/latest/object_store/enum.ClientConfigKey.html -OBJECT_STORE_CLIENT_OPTIONS: frozenset[str] = frozenset( - [ - "allow_http", - "allow_invalid_certificates", - "connect_timeout", - "default_content_type", - "http1_only", - "http2_only", - "http2_keep_alive_interval", - "http2_keep_alive_timeout", - "http2_keep_alive_while_idle", - "http2_max_frame_size", - "pool_idle_timeout", - "pool_max_idle_per_host", - "proxy_url", - "proxy_ca_certificate", - "proxy_excludes", - "timeout", - "user_agent", - ] -) - class AWSAssumeRoleKWArgs(TypedDict): """Parameters for [STS.Client.assume_role()](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role.html#STS.Client.assume_role).""" @@ -136,8 +114,8 @@ def __call__(self) -> CredentialProviderFunctionReturn: creds = session.get_credentials() if creds is None: - msg = "CredentialProviderAWS: unexpected None value returned from boto3.Session.get_credentials()" - raise ValueError(msg) + msg = "did not receive any credentials from boto3.Session.get_credentials()" + raise self.EmptyCredentialError(msg) return { "aws_access_key_id": creds.access_key, @@ -169,6 +147,14 @@ def _ensure_module_availability(cls) -> None: msg = "boto3 must be installed to 
use `CredentialProviderAWS`" raise ImportError(msg) + class EmptyCredentialError(Exception): + """ + Raised when boto3 returns empty credentials. + + This generally indicates that no credentials could be found in the + environment. + """ + class CredentialProviderAzure(CredentialProvider): """ @@ -230,15 +216,12 @@ def __init__( elif self._try_get_azure_storage_account_credentials_if_permitted() is None: self._ensure_module_availability() - if os.getenv("POLARS_VERBOSE") == "1": - print( - ( - "[CredentialProviderAzure]: " - f"{self.account_name = } " - f"{self.tenant_id = } " - f"{self.scopes = } " - ), - file=sys.stderr, + if verbose(): + eprint( + "[CredentialProviderAzure]: " + f"{self.account_name = } " + f"{self.tenant_id = } " + f"{self.scopes = } " ) def __call__(self) -> CredentialProviderFunctionReturn: @@ -268,14 +251,13 @@ def _try_get_azure_storage_account_credentials_if_permitted( "POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY" ) - verbose = os.getenv("POLARS_VERBOSE") == "1" + verbose = polars._utils.logging.verbose() if verbose: - print( + eprint( "[CredentialProviderAzure]: " f"{self.account_name = } " - f"{POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY = }", - file=sys.stderr, + f"{POLARS_AUTO_USE_AZURE_STORAGE_ACCOUNT_KEY = }" ) if ( @@ -290,15 +272,13 @@ def _try_get_azure_storage_account_credentials_if_permitted( } if verbose: - print( - "[CredentialProviderAzure]: Retrieved account key from Azure CLI", - file=sys.stderr, + eprint( + "[CredentialProviderAzure]: Retrieved account key from Azure CLI" ) except Exception as e: if verbose: - print( - f"[CredentialProviderAzure]: Could not retrieve account key from Azure CLI: {e}", - file=sys.stderr, + eprint( + f"[CredentialProviderAzure]: Could not retrieve account key from Azure CLI: {e}" ) else: return creds, None @@ -449,170 +429,6 @@ def _ensure_module_availability(cls) -> None: raise ImportError(msg) -def _maybe_init_credential_provider( - credential_provider: CredentialProviderFunction | 
Literal["auto"] | None, - source: Any, - storage_options: dict[str, Any] | None, - caller_name: str, -) -> CredentialProviderFunction | CredentialProvider | None: - from polars.io.cloud._utils import ( - _first_scan_path, - _get_path_scheme, - _is_aws_cloud, - _is_azure_cloud, - _is_gcp_cloud, - ) - - if credential_provider is not None: - msg = f"The `credential_provider` parameter of `{caller_name}` is considered unstable." - issue_unstable_warning(msg) - - if credential_provider != "auto": - return credential_provider - - verbose = os.getenv("POLARS_VERBOSE") == "1" - - if (path := _first_scan_path(source)) is None: - return None - - if (scheme := _get_path_scheme(path)) is None: - return None - - provider: ( - CredentialProviderAWS | CredentialProviderAzure | CredentialProviderGCP | None - ) = None - - try: - # For Azure we dispatch to `azure.identity` as much as possible - if _is_azure_cloud(scheme): - tenant_id = None - storage_account = None - - if storage_options is not None: - for k, v in storage_options.items(): - k = k.lower() - - # https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html - if k in { - "azure_storage_tenant_id", - "azure_storage_authority_id", - "azure_tenant_id", - "azure_authority_id", - "tenant_id", - "authority_id", - }: - tenant_id = v - elif k in {"azure_storage_account_name", "account_name"}: - storage_account = v - elif k in {"azure_use_azure_cli", "use_azure_cli"}: - continue - elif k in OBJECT_STORE_CLIENT_OPTIONS: - continue - else: - # We assume some sort of access key was given, so we - # just dispatch to the rust side. 
- return None - - storage_account = ( - # Prefer the one embedded in the path - CredentialProviderAzure._extract_adls_uri_storage_account(str(path)) - or storage_account - ) - - provider = CredentialProviderAzure( - tenant_id=tenant_id, - _storage_account=storage_account, - ) - elif _is_aws_cloud(scheme): - region = None - profile = None - default_region = None - unhandled_key = None - - if storage_options is not None: - for k, v in storage_options.items(): - k = k.lower() - - # https://docs.rs/object_store/latest/object_store/aws/enum.AmazonS3ConfigKey.html - if k in {"aws_region", "region"}: - region = v - elif k in {"aws_default_region", "default_region"}: - default_region = v - elif k in {"aws_profile", "profile"}: - profile = v - elif k in OBJECT_STORE_CLIENT_OPTIONS: - continue - else: - # We assume some sort of access key was given, so we - # just dispatch to the rust side. - unhandled_key = k - - to_silence_this_warning = ( - "To silence this warning, pass 'aws_profile': None in storage_options." - ) - - if unhandled_key is not None: - if profile is not None: - msg = ( - f"the configured AWS profile '{profile}' may be ignored " - "as it is not compatible with the provided " - f"storage_option key '{unhandled_key}'. " - f"{to_silence_this_warning}" - ) - issue_warning(msg, UserWarning) - - return None - - try: - provider = CredentialProviderAWS( - profile_name=profile, region_name=region or default_region - ) - except ImportError: - if profile is not None: - msg = ( - f"the configured AWS profile '{profile}' may not " - "be used as boto3 is not installed. " - f"{to_silence_this_warning}" - ) - # Conservatively warn instead of hard error. It could just be - # set as a default environment flag. 
- issue_warning(msg, UserWarning) - # Note: Enclosing scope will catch ImportErrors - raise - - elif storage_options is not None and any( - key.lower() not in OBJECT_STORE_CLIENT_OPTIONS for key in storage_options - ): - return None - elif _is_gcp_cloud(scheme): - provider = CredentialProviderGCP() - - except ImportError as e: - if verbose: - msg = f"unable to auto-select credential provider: {e!r}" - print(msg, file=sys.stderr) - - if provider is not None: - # CredentialProviderAWS raises an error in some cases when - # `get_credentials()` returns None (e.g. the environment may not - # have / require credentials). We check this here and avoid - # using it if that is the case. - try: - provider() - except Exception as e: - provider = None - - if verbose: - msg = f"unable to auto-select credential provider: {e!r}" - print(msg, file=sys.stderr) - - if provider is not None and verbose: - msg = f"auto-selected credential provider: {type(provider).__name__}" - print(msg, file=sys.stderr) - - return provider - - def _get_credentials_from_provider_expiry_aware( credential_provider: CredentialProviderFunction, ) -> dict[str, str]: @@ -622,15 +438,10 @@ def _get_credentials_from_provider_expiry_aware( opt_expiry is not None and (expires_in := opt_expiry - int(datetime.now().timestamp())) < 7 ): - import os - import sys from time import sleep - if os.getenv("POLARS_VERBOSE") == "1": - print( - f"waiting for {expires_in} seconds for refreshed credentials", - file=sys.stderr, - ) + if verbose(): + eprint(f"waiting for {expires_in} seconds for refreshed credentials") sleep(1 + expires_in) creds, _ = credential_provider() diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 67fcc78682f3..0f1d32d668dd 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -24,7 +24,9 @@ parse_row_index_args, prepare_file_arg, ) -from polars.io.cloud.credential_provider import _maybe_init_credential_provider +from 
polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, +) from polars.io.csv._utils import _check_arg_is_1byte, _update_columns from polars.io.csv.batched_reader import BatchedCsvReader @@ -37,6 +39,7 @@ from polars import DataFrame, LazyFrame from polars._typing import CsvEncoding, PolarsDataType, SchemaDict from polars.io.cloud import CredentialProviderFunction + from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder @deprecate_renamed_parameter("dtypes", "schema_overrides", version="0.20.31") @@ -1311,9 +1314,10 @@ def with_column_names(cols: list[str]) -> list[str]: if not infer_schema: infer_schema_length = 0 - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, storage_options, "scan_csv" ) + del credential_provider return _scan_csv_impl( source, @@ -1346,7 +1350,7 @@ def with_column_names(cols: list[str]) -> list[str]: glob=glob, retries=retries, storage_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, file_cache_ttl=file_cache_ttl, include_file_paths=include_file_paths, ) @@ -1391,7 +1395,7 @@ def _scan_csv_impl( decimal_comma: bool = False, glob: bool = True, storage_options: dict[str, Any] | None = None, - credential_provider: CredentialProviderFunction | None = None, + credential_provider: CredentialProviderBuilder | None = None, retries: int = 2, file_cache_ttl: int | None = None, include_file_paths: str | None = None, diff --git a/py-polars/polars/io/delta.py b/py-polars/polars/io/delta.py index 4ac7a704fd6e..d7600d09f2fc 100644 --- a/py-polars/polars/io/delta.py +++ b/py-polars/polars/io/delta.py @@ -296,24 +296,30 @@ def scan_delta( from deltalake import DeltaTable - from polars.io.cloud.credential_provider import ( + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) + 
from polars.io.cloud.credential_provider._providers import ( _get_credentials_from_provider_expiry_aware, - _maybe_init_credential_provider, ) if not isinstance(source, DeltaTable): - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, storage_options, "scan_delta" ) elif credential_provider is not None and credential_provider != "auto": msg = "cannot use credential_provider when passing a DeltaTable object" raise ValueError(msg) else: - credential_provider = None + credential_provider_builder = None + + del credential_provider - if credential_provider is not None: + if credential_provider_builder and ( + provider := credential_provider_builder.build_credential_provider() + ): credential_provider_creds = _get_credentials_from_provider_expiry_aware( - credential_provider + provider ) dl_tbl = _get_delta_lake_table( @@ -321,7 +327,7 @@ def scan_delta( version=version, storage_options=( {**(storage_options or {}), **credential_provider_creds} - if storage_options is not None or credential_provider is not None + if storage_options is not None or credential_provider_builder is not None else None ), delta_table_options=delta_table_options, @@ -404,7 +410,7 @@ def _split_schema( allow_missing_columns=True, hive_partitioning=len(partition_columns) > 0, storage_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, # type: ignore[arg-type] rechunk=rechunk or False, ) diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 984ec54d966f..c933728810c9 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -22,7 +22,9 @@ parse_row_index_args, prepare_file_arg, ) -from polars.io.cloud.credential_provider import _maybe_init_credential_provider +from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, +) 
with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -460,9 +462,10 @@ def scan_ipc( # Memory Mapping is now a no-op _ = memory_map - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, storage_options, "scan_parquet" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -478,7 +481,7 @@ def scan_ipc( rechunk, parse_row_index_args(row_index_name, row_index_offset), cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, file_cache_ttl=file_cache_ttl, hive_partitioning=hive_partitioning, diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index 445889712244..a99e4eb4e1f5 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -11,7 +11,9 @@ from polars._utils.wrap import wrap_df, wrap_ldf from polars.datatypes import N_INFER_DEFAULT from polars.io._utils import parse_row_index_args -from polars.io.cloud.credential_provider import _maybe_init_credential_provider +from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, +) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PyLazyFrame @@ -157,10 +159,12 @@ def read_ndjson( return df - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, storage_options, "read_ndjson" ) + del credential_provider + return scan_ndjson( source, schema=schema, @@ -176,7 +180,7 @@ def read_ndjson( include_file_paths=include_file_paths, retries=retries, storage_options=storage_options, - credential_provider=credential_provider, + 
credential_provider=credential_provider_builder, # type: ignore[arg-type] file_cache_ttl=file_cache_ttl, ).collect() @@ -300,10 +304,12 @@ def scan_ndjson( msg = "'infer_schema_length' should be positive" raise ValueError(msg) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, storage_options, "scan_ndjson" ) + del credential_provider + if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] else: @@ -325,7 +331,7 @@ def scan_ndjson( include_file_paths=include_file_paths, retries=retries, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, file_cache_ttl=file_cache_ttl, ) return wrap_ldf(pylf) diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index 7080738e2e84..e0a0e6ccb951 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -21,7 +21,9 @@ parse_row_index_args, prepare_file_arg, ) -from polars.io.cloud.credential_provider import _maybe_init_credential_provider +from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, +) with contextlib.suppress(ImportError): from polars.polars import PyLazyFrame @@ -33,6 +35,7 @@ from polars import DataFrame, DataType, LazyFrame from polars._typing import FileSource, ParallelStrategy, SchemaDict from polars.io.cloud import CredentialProviderFunction + from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @@ -482,7 +485,7 @@ def scan_parquet( normalize_filepath(source, check_not_directory=False) for source in source ] - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, source, 
storage_options, "scan_parquet" ) @@ -495,7 +498,7 @@ def scan_parquet( row_index_name=row_index_name, row_index_offset=row_index_offset, storage_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, low_memory=low_memory, use_statistics=use_statistics, hive_partitioning=hive_partitioning, @@ -519,7 +522,7 @@ def _scan_parquet_impl( row_index_name: str | None = None, row_index_offset: int = 0, storage_options: dict[str, object] | None = None, - credential_provider: CredentialProviderFunction | None = None, + credential_provider: CredentialProviderBuilder | None = None, low_memory: bool = False, use_statistics: bool = True, hive_partitioning: bool | None = None, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 3abfeb6c95b2..a39c246965f8 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2438,11 +2438,14 @@ def sink_parquet( "null_count": True, } - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, path, storage_options, "sink_parquet" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -2459,7 +2462,7 @@ def sink_parquet( data_page_size=data_page_size, maintain_order=maintain_order, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) @@ -2562,11 +2565,14 @@ def sink_ipc( no_optimization=no_optimization, ) - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, 
+ ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, path, storage_options, "sink_ipc" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -2579,7 +2585,7 @@ def sink_ipc( compression=compression, maintain_order=maintain_order, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) @@ -2749,11 +2755,14 @@ def sink_csv( no_optimization=no_optimization, ) - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, path, storage_options, "sink_csv" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -2778,7 +2787,7 @@ def sink_csv( quote_style=quote_style, maintain_order=maintain_order, cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) @@ -2877,11 +2886,14 @@ def sink_ndjson( no_optimization=no_optimization, ) - from polars.io.cloud.credential_provider import _maybe_init_credential_provider + from polars.io.cloud.credential_provider._builder import ( + _init_credential_provider_builder, + ) - credential_provider = _maybe_init_credential_provider( + credential_provider_builder = _init_credential_provider_builder( credential_provider, path, storage_options, "sink_ndjson" ) + del credential_provider if storage_options: storage_options = list(storage_options.items()) # type: ignore[assignment] @@ -2893,7 +2905,7 @@ def sink_ndjson( path=path, maintain_order=maintain_order, 
cloud_options=storage_options, - credential_provider=credential_provider, + credential_provider=credential_provider_builder, retries=retries, ) diff --git a/py-polars/tests/unit/io/cloud/test_credential_provider.py b/py-polars/tests/unit/io/cloud/test_credential_provider.py index 8cc39ce450c8..eec36cf1c08d 100644 --- a/py-polars/tests/unit/io/cloud/test_credential_provider.py +++ b/py-polars/tests/unit/io/cloud/test_credential_provider.py @@ -1,9 +1,11 @@ import io +import pickle from typing import Any import pytest import polars as pl +import polars.io.cloud.credential_provider from polars.exceptions import ComputeError @@ -16,7 +18,7 @@ pl.scan_ipc, ], ) -def test_scan_credential_provider( +def test_credential_provider_scan( io_func: Any, monkeypatch: pytest.MonkeyPatch ) -> None: err_magic = "err_magic_3" @@ -24,7 +26,9 @@ def test_scan_credential_provider( def raises(*_: None, **__: None) -> None: raise AssertionError(err_magic) - monkeypatch.setattr(pl.CredentialProviderAWS, "__init__", raises) + from polars.io.cloud.credential_provider._builder import CredentialProviderBuilder + + monkeypatch.setattr(CredentialProviderBuilder, "__init__", raises) with pytest.raises(AssertionError, match=err_magic): io_func("s3://bucket/path", credential_provider="auto") @@ -55,13 +59,52 @@ def raises(*_: None, **__: None) -> None: def raises_2() -> pl.CredentialProviderFunctionReturn: raise AssertionError(err_magic) - # Note to reader: It is converted to a ComputeError as it is being called - # from Rust. 
- with pytest.raises(ComputeError, match=err_magic): + with pytest.raises(AssertionError, match=err_magic): io_func("s3://bucket/path", credential_provider=raises_2).collect() -def test_scan_credential_provider_serialization() -> None: +@pytest.mark.parametrize( + ("provider_class", "path"), + [ + (polars.io.cloud.credential_provider.CredentialProviderAWS, "s3://.../..."), + (polars.io.cloud.credential_provider.CredentialProviderGCP, "gs://.../..."), + (polars.io.cloud.credential_provider.CredentialProviderAzure, "az://.../..."), + ], +) +def test_credential_provider_serialization_auto_init( + provider_class: polars.io.cloud.credential_provider.CredentialProvider, + path: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def raises_1(*a: Any, **kw: Any) -> None: + msg = "err_magic_1" + raise AssertionError(msg) + + monkeypatch.setattr(provider_class, "__init__", raises_1) + + # Credential provider should not be initialized during query plan construction. + q = pl.scan_parquet(path) + + # Check baseline - query plan is configured to auto-initialize the credential + # provider. + with pytest.raises(pl.exceptions.ComputeError, match="err_magic_1"): + q.collect() + + q = pickle.loads(pickle.dumps(q)) + + def raises_2(*a: Any, **kw: Any) -> None: + msg = "err_magic_2" + raise AssertionError(msg) + + monkeypatch.setattr(provider_class, "__init__", raises_2) + + # Check that auto-initialization happens upon executing the deserialized + # query. 
+ with pytest.raises(pl.exceptions.ComputeError, match="err_magic_2"): + q.collect() + + +def test_credential_provider_serialization_custom_provider() -> None: err_magic = "err_magic_3" class ErrCredentialProvider(pl.CredentialProvider): @@ -80,7 +123,7 @@ def __call__(self) -> pl.CredentialProviderFunctionReturn: lf.collect() -def test_credential_provider_skips_config_autoload( +def test_credential_provider_skips_google_config_autoload( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv("GOOGLE_SERVICE_ACCOUNT_PATH", "__non_existent")