Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions datafusion/functions/src/core/arrow_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

//! [`ArrowCastFunc`]: Implementation of the `arrow_cast`

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit};
use arrow::error::ArrowError;
use datafusion_common::{
Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt,
exec_datafusion_err, exec_err, internal_err, types::logical_string,
utils::take_function_args,
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err,
types::logical_string, utils::take_function_args,
};

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
Expand Down Expand Up @@ -140,10 +140,9 @@ impl ScalarUDFImpl for ArrowCastFunc {
self.name()
)
},
|casted_type| match casted_type.parse::<DataType>() {
Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()),
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
Err(e) => Err(arrow_datafusion_err!(e)),
|casted_type| {
let data_type = parse_arrow_cast_data_type(casted_type)?;
Ok(Field::new(self.name(), data_type, nullable).into())
},
)
}
Expand Down Expand Up @@ -189,10 +188,35 @@ pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result<Dat
);
};

val.parse().map_err(|e| match e {
// If the data type cannot be parsed, return a Plan error to signal an
// error in the input rather than a more general ArrowError
parse_arrow_cast_data_type(val)
}

pub(crate) fn parse_arrow_cast_data_type(casted_type: &str) -> Result<DataType> {
let data_type = casted_type.parse().map_err(|e| match e {
ArrowError::ParseError(e) => exec_datafusion_err!("{e}"),
e => arrow_datafusion_err!(e),
})
})?;

validate_arrow_cast_data_type(&data_type)?;
Ok(data_type)
}

fn validate_arrow_cast_data_type(data_type: &DataType) -> Result<()> {
match data_type {
DataType::Time32(unit @ (TimeUnit::Microsecond | TimeUnit::Nanosecond)) => {
Err(plan_datafusion_err!(
"Invalid Arrow type combination: Time32 only supports Second and Millisecond, got {:?}. Use Time64({:?}) for sub-millisecond precision",
unit,
unit
))
}
DataType::Time64(unit @ (TimeUnit::Second | TimeUnit::Millisecond)) => {
Err(plan_datafusion_err!(
"Invalid Arrow type combination: Time64 only supports Microsecond and Nanosecond, got {:?}. Use Time32({:?}) for second or millisecond precision",
unit,
unit
))
}
_ => Ok(()),
}
}
16 changes: 6 additions & 10 deletions datafusion/functions/src/core/arrow_try_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast`

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
use datafusion_common::{
Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err,
internal_err, types::logical_string, utils::take_function_args,
Result, datatype::DataTypeExt, exec_err, internal_err, types::logical_string,
utils::take_function_args,
};

use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
Expand All @@ -31,7 +30,7 @@ use datafusion_expr::{
};
use datafusion_macros::user_doc;

use super::arrow_cast::data_type_from_type_arg;
use super::arrow_cast::{data_type_from_type_arg, parse_arrow_cast_data_type};

/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring.
///
Expand Down Expand Up @@ -111,12 +110,9 @@ impl ScalarUDFImpl for ArrowTryCastFunc {
self.name()
)
},
|casted_type| match casted_type.parse::<DataType>() {
Ok(data_type) => {
Ok(Field::new(self.name(), data_type, true).into())
}
Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")),
Err(e) => Err(arrow_datafusion_err!(e)),
|casted_type| {
let data_type = parse_arrow_cast_data_type(casted_type)?;
Ok(Field::new(self.name(), data_type, true).into())
},
)
}
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ SELECT arrow_cast('1', arrow_cast('Utf8', 'Utf8'))
query error DataFusion error: Execution error: Unsupported type 'unknown'\. Must be a supported arrow type name such as 'Int32' or 'Timestamp\(ns\)'\. Error unknown token: unknown
SELECT arrow_cast('1', 'unknown')

query error Invalid Arrow type combination: Time32 only supports Second and Millisecond, got Microsecond
SELECT arrow_cast(0, 'Time32(Microsecond)') + 1

query error Invalid Arrow type combination: Time32 only supports Second and Millisecond, got Nanosecond
SELECT arrow_cast(0, 'Time32(Nanosecond)') + 1

query error Invalid Arrow type combination: Time64 only supports Microsecond and Nanosecond, got Second
SELECT arrow_cast(0, 'Time64(Second)') + 1

query error Invalid Arrow type combination: Time64 only supports Microsecond and Nanosecond, got Millisecond
SELECT arrow_cast(0, 'Time64(Millisecond)') + 1

# Round Trip tests:
query TTTTTTTTTTTTTTTTTTTTTTTTT
SELECT
Expand Down
Loading