diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 5c409ef0ddd..eb9c8fa2523 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -292,7 +292,7 @@ mod tests { use spacetimedb_lib::error::ResultTest; use spacetimedb_lib::sats::time_duration::TimeDuration; use spacetimedb_lib::sats::timestamp::Timestamp; - use spacetimedb_lib::sats::{product, GroundSpacetimeType, ProductType}; + use spacetimedb_lib::sats::{product, ArrayValue, GroundSpacetimeType, ProductType}; use spacetimedb_lib::{AlgebraicType, AlgebraicValue, ConnectionId, Identity, Uuid}; fn make_row(row: &[AlgebraicValue]) -> Result, serde_json::Error> { @@ -512,6 +512,48 @@ Roundtrip time: 1.00ms"#, assert_eq!(expected, table); } + #[test] + fn output_arrays() -> ResultTest<()> { + let kind: ProductType = [ + ("ints", AlgebraicType::array(AlgebraicType::I32)), + ("strings", AlgebraicType::array(AlgebraicType::String)), + ("nested", AlgebraicType::array(AlgebraicType::array(AlgebraicType::I32))), + ("bytes", AlgebraicType::bytes()), + ] + .into(); + + let value = product![ + AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())), + AlgebraicValue::Array(ArrayValue::String(["one".into(), "two".into()].into())), + AlgebraicValue::Array(ArrayValue::Array( + [ArrayValue::I32([1, 2].into()), ArrayValue::I32([3, 4].into())].into() + )), + AlgebraicValue::Bytes([0xde, 0xad].into()), + ]; + + expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + ints | strings | nested | bytes +-----------+----------------+------------------+-------- + [1, 2, 3] | ["one", "two"] | [[1, 2], [3, 4]] | 0xdead"#, + ); + + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value], + r#" + ints | strings | nested | bytes +-----------+----------------+------------------+---------- + {1, 2, 3} | {"one", "two"} | {{1, 2}, {3, 4}} | "0xdead""#, + ); + + Ok(()) + } + // Verify the output of `sql` matches the inputs that return true for [`AlgebraicType::is_special()`] #[test] fn output_special_types() -> ResultTest<()> { diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs index 01fda799920..fb120e3583f 100644 --- a/crates/pg/src/encoder.rs +++ b/crates/pg/src/encoder.rs @@ -2,8 +2,8 @@ use crate::pg_server::PgError; use pgwire::api::portal::Format; use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::api::Type; -use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter}; -use spacetimedb_lib::sats::{satn, ValueWithType}; +use spacetimedb_lib::sats::satn::{PsqlChars, PsqlClient, PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::{satn, ArrayValue, ValueWithType}; use spacetimedb_lib::{ ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, Uuid, }; @@ -54,7 +54,11 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => Type::NUMERIC_ARRAY, - _ => Type::ANYARRAY, + AlgebraicType::F32 => Type::FLOAT4_ARRAY, + AlgebraicType::F64 => Type::FLOAT8_ARRAY, + AlgebraicType::Ref(_) | AlgebraicType::Sum(_) | AlgebraicType::Product(_) | AlgebraicType::Array(_) => { + Type::JSON_ARRAY + } }, AlgebraicType::Product(_) => match format { PsqlPrintFmt::Hex => Type::BYTEA, @@ -155,7 +159,9 @@ impl TypedWriter for PsqlFormatter<'_> { return Ok(()); } - let PsqlChars { start, sep, end, quote } = ty.client.format_chars(); + let PsqlChars { + start, sep, end, quote, .. + } = ty.client.format_chars(); let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); let json = format!( "{start}{quote}{name}{quote}{sep} {}{end}", @@ -164,6 +170,78 @@ impl TypedWriter for PsqlFormatter<'_> { self.encoder.encode_field(&json)?; Ok(()) } + + fn write_array( + &mut self, + value: &ValueWithType<'_, ArrayValue>, + psql: &PsqlType, + ty: &AlgebraicType, + ) -> Result { + // `array` is a byte array in SQL output, so keep the existing bytea path. + if *ty == AlgebraicType::U8 { + return Ok(false); + } + + fn collect(arr: &[I], map: F) -> Vec + where + F: FnMut(&I) -> O, + { + arr.iter().map(map).collect() + } + + let complex_value = |elem: AlgebraicValue, elem_ty: &AlgebraicType, client| { + let tuple = ProductType::from([elem_ty.clone()]); + let psql_ty = PsqlType { + client, + tuple: &tuple, + field: &tuple.elements[0], + idx: 0, + }; + satn::PsqlWrapper { + ty: psql_ty, + value: value.with(elem_ty, &elem), + } + .to_string() + }; + + match value.value() { + ArrayValue::Bool(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::I8(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U8(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::I16(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U16(arr) => self.encoder.encode_field(&collect(arr, |v| i32::from(*v)))?, + ArrayValue::I32(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U32(arr) => self.encoder.encode_field(&collect(arr, |v| i64::from(*v)))?, + ArrayValue::I64(arr) => self.encoder.encode_field(&arr.as_ref())?, + ArrayValue::U64(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::I128(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::U128(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::I256(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::U256(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::F32(arr) => self.encoder.encode_field(&collect(arr, |v| *v.as_ref()))?, + ArrayValue::F64(arr) => self.encoder.encode_field(&collect(arr, |v| *v.as_ref()))?, + ArrayValue::String(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?, + ArrayValue::Array(arr) => { + // Nested arrays are exposed as JSON arrays for the PostgreSQL wire protocol. + let values = collect(arr, |v| { + complex_value(AlgebraicValue::Array(v.clone()), ty, PsqlClient::SpacetimeDB) + }); + self.encoder.encode_field(&values)?; + } + ArrayValue::Sum(arr) => { + let values = collect(arr, |v| complex_value(AlgebraicValue::Sum(v.clone()), ty, psql.client)); + self.encoder.encode_field(&values)?; + } + ArrayValue::Product(arr) => { + let values = collect(arr, |v| { + complex_value(AlgebraicValue::Product(v.clone()), ty, psql.client) + }); + self.encoder.encode_field(&values)?; + } + } + + Ok(true) + } } #[cfg(test)] @@ -173,7 +251,7 @@ mod tests { use futures::StreamExt; use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_lib::sats::algebraic_value::Packed; - use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant}; + use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ArrayValue, ProductType, SumTypeVariant}; use spacetimedb_lib::{ConnectionId, Identity}; async fn run(schema: ProductType, row: ProductValue) -> String { @@ -236,6 +314,30 @@ mod tests { assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t"); } + #[tokio::test] + async fn test_array() { + let schema = ProductType::from([ + AlgebraicType::array(AlgebraicType::I32), + AlgebraicType::array(AlgebraicType::String), + AlgebraicType::array(AlgebraicType::array(AlgebraicType::I32)), + AlgebraicType::bytes(), + ]); + let value = product![ + AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())), + AlgebraicValue::Array(ArrayValue::String(["one".into(), "two".into()].into())), + AlgebraicValue::Array(ArrayValue::Array( + [ArrayValue::I32([1, 2].into()), ArrayValue::I32([3, 4].into())].into() + )), + AlgebraicValue::Bytes([0xde, 0xad].into()), + ]; + + let row = run(schema, value).await; + assert_eq!( + row, + "\0\0\0\u{7}{1,2,3}\0\0\0\t{one,two}\0\0\0\u{13}{\"[1, 2]\",\"[3, 4]\"}\0\0\0\u{6}\\xdead" + ); + } + #[tokio::test] async fn test_enum() { let some = AlgebraicType::option(AlgebraicType::I64); diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index afaaf421264..75bbba558e3 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,7 +1,7 @@ use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; use crate::uuid::Uuid; -use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; +use crate::{i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; @@ -453,8 +453,10 @@ pub enum PsqlClient { pub struct PsqlChars { pub start: char, + pub start_array: &'static str, pub sep: &'static str, pub end: char, + pub end_array: &'static str, pub quote: &'static str, } @@ -463,14 +465,18 @@ impl PsqlClient { match self { PsqlClient::SpacetimeDB => PsqlChars { start: '(', + start_array: "[", sep: " =", end: ')', + end_array: "]", quote: "", }, PsqlClient::Postgres => PsqlChars { start: '{', + start_array: "{", sep: ":", end: '}', + end_array: "}", quote: "\"", }, } @@ -588,6 +594,17 @@ pub trait TypedWriter { Ok(false) } + /// Writes an array as a single value. Returns `false` to use the default + /// typed serialization path instead. + fn write_array( + &mut self, + _value: &ValueWithType<'_, ArrayValue>, + _psql: &PsqlType, + _ty: &AlgebraicType, + ) -> Result { + Ok(false) + } + fn write_record( &mut self, fields: Vec<(Cow, PsqlType, ValueWithType)>, @@ -764,6 +781,39 @@ impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { Ok(TypedArrayFormatter { ty: self.ty, f: self.f }) } + fn serialize_array_raw(self, value: &ValueWithType<'_, ArrayValue>) -> Result { + let mut ty = &*value.ty().elem_ty; + while let AlgebraicType::Ref(r) = ty { + ty = &value.typespace()[*r]; + } + if self.f.write_array(value, self.ty, ty)? { + return Ok(()); + } + match (value.value(), ty) { + (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Product(v), AlgebraicType::Product(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(self), + (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(self), + (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(self), + (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(self), + (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(self), + (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(self), + (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(self), + (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(self), + (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(self), + (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(self), + (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(self), + (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(self), + (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(self), + (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(self), + (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(self), + (ArrayValue::String(v), AlgebraicType::String) => v.serialize(self), + (ArrayValue::Array(v), AlgebraicType::Array(ty)) => value.with(ty, v).serialize(self), + (val, _) if val.is_empty() => ser::SerializeArray::end(self.serialize_array(0)?), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), + } + } + fn serialize_seq_product(self, _len: usize) -> Result { Ok(TypedSeqFormatter { ty: self.ty, f: self.f }) } @@ -893,11 +943,53 @@ impl TypedWriter for SqlFormatter<'_, '_> { write!(self.fmt, "\"{value}\"") } + fn write_array( + &mut self, + value: &ValueWithType<'_, ArrayValue>, + _psql: &PsqlType, + ty: &AlgebraicType, + ) -> Result { + // `array` is rendered as bytes in SQL output. + if *ty == AlgebraicType::U8 { + return Ok(false); + } + + let PsqlChars { + start_array, end_array, .. + } = self.ty.client.format_chars(); + write!(self.fmt, "{start_array}")?; + let tuple = ProductType::from([ty.clone()]); + let field = &tuple.elements[0]; + for (idx, elem) in value.value().iter_cloned().enumerate() { + if idx > 0 { + write!(self.fmt, ", ")?; + } + let psql_ty = PsqlType { + client: self.ty.client, + tuple: &tuple, + field, + idx: 0, + }; + write!( + self.fmt, + "{}", + PsqlWrapper { + ty: psql_ty, + value: value.with(ty, &elem) + } + )?; + } + write!(self.fmt, "{end_array}")?; + Ok(true) + } + fn write_record( &mut self, fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, ) -> Result<(), Self::Error> { - let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars(); + let PsqlChars { + start, sep, end, quote, .. + } = self.ty.client.format_chars(); write!(self.fmt, "{start}")?; for (idx, (name, ty, value)) in fields.into_iter().enumerate() { if idx > 0 { diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index 229d1d7c0d8..063c30cb380 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -6,8 +6,10 @@ mod impls; pub mod serde; use crate::de::DeserializeSeed; -use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType}; -use crate::{AlgebraicValue, WithTypespace}; +use crate::{ + algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ArrayValue, ProductValue, SumValue, ValueWithType, +}; +use crate::{AlgebraicType, AlgebraicValue, WithTypespace}; use core::marker::PhantomData; use core::{convert::Infallible, fmt}; use ethnum::{i256, u256}; @@ -142,6 +144,44 @@ pub trait Serializer: Sized { self.serialize_variant(tag, var_ty.name().map(|n| &**n), &sum.with(&var_ty.algebraic_type, val)) } + /// Serialize an array value with its static element type. + /// + /// Allow to override the default serialization for formats that can encode + /// array values directly rather than streaming each element independently. + fn serialize_array_raw(self, value: &ValueWithType<'_, ArrayValue>) -> Result { + let mut ty = &*value.ty().elem_ty; + loop { + // We're doing this because of `Ref`s. + break match (value.value(), ty) { + (_, &AlgebraicType::Ref(r)) => { + ty = &value.typespace()[r]; + continue; + } + (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Product(v), AlgebraicType::Product(ty)) => value.with(ty, v).serialize(self), + (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(self), + (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(self), + (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(self), + (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(self), + (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(self), + (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(self), + (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(self), + (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(self), + (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(self), + (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(self), + (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(self), + (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(self), + (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(self), + (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(self), + (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(self), + (ArrayValue::String(v), AlgebraicType::String) => v.serialize(self), + (ArrayValue::Array(v), AlgebraicType::Array(ty)) => value.with(ty, v).serialize(self), + (val, _) if val.is_empty() => self.serialize_array(0)?.end(), + (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), + }; + } + } + /// Serialize a sum value provided the chosen `tag`, `name`, and `value`. fn serialize_variant( self, diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index a4b008bad48..8d016e538d8 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -227,36 +227,7 @@ impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { ser.serialize_named_product_raw(self) }); impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => { - let mut ty = &*self.ty().elem_ty; - loop { // We're doing this because of `Ref`s. - break match (self.value(), ty) { - (_, &AlgebraicType::Ref(r)) => { - ty = &self.typespace()[r]; - continue; - } - (ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => self.with(ty, v).serialize(ser), - (ArrayValue::Product(v), AlgebraicType::Product(ty)) => self.with(ty, v).serialize(ser), - (ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(ser), - (ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(ser), - (ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(ser), - (ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(ser), - (ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(ser), - (ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(ser), - (ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(ser), - (ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(ser), - (ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(ser), - (ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(ser), - (ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(ser), - (ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(ser), - (ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(ser), - (ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(ser), - (ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(ser), - (ArrayValue::String(v), AlgebraicType::String) => v.serialize(ser), - (ArrayValue::Array(v), AlgebraicType::Array(ty)) => self.with(ty, v).serialize(ser), - (val, _) if val.is_empty() => ser.serialize_array(0)?.end(), - (val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"), - } - } + ser.serialize_array_raw(self) }); impl_serialize!([] spacetimedb_primitives::ArgId, (self, ser) => ser.serialize_u64(self.0));