diff --git a/Makefile b/Makefile index 74eb318..4445bf3 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,8 @@ BLAS_SRCS = \ $(BLAS_SRC_DIR_L1)/ddot.c \ $(BLAS_SRC_DIR_L1)/hdot.c \ $(BLAS_SRC_DIR_L1)/qdot.c \ + $(BLAS_SRC_DIR_L1)/sdsdot.c \ + $(BLAS_SRC_DIR_L1)/hsdot.c \ $(BLAS_SRC_DIR_L1)/snrm2.c \ $(BLAS_SRC_DIR_L1)/dnrm2.c \ $(BLAS_SRC_DIR_L1)/hnrm2.c \ diff --git a/include/softblas.h b/include/softblas.h index 9e2dbb4..3a51f7b 100644 --- a/include/softblas.h +++ b/include/softblas.h @@ -220,6 +220,8 @@ float128_t qasum(uint64_t N, const float128_t *QX, uint64_t incX, const uint_fas void qaxpy(uint64_t N, float128_t QA, float128_t *QX, int64_t incX, float128_t *QY, int64_t incY, const uint_fast8_t rndMode); void qcopy(uint64_t N, const float128_t *QX, int64_t incX, float128_t *QY, int64_t incY, const uint_fast8_t rndMode); float128_t qdot(const uint64_t N, const float128_t *X, const int64_t incX, const float128_t *Y, const int64_t incY, const uint_fast8_t rndMode); +float32_t sdsdot(const uint64_t N, const float32_t alpha, const float32_t *SX, const int64_t incX, const float32_t *SY, const int64_t incY, const uint_fast8_t rndMode); +float32_t hsdot(const uint64_t N, const float16_t alpha, const float16_t *X, const int64_t incX, const float16_t *Y, const int64_t incY, const uint_fast8_t rndMode); float128_t qnrm2(uint64_t N, const float128_t *X, uint64_t incX, const uint_fast8_t rndMode); void qrot(const uint64_t N, float16_t *X, const uint64_t incX, float16_t *Y, const uint64_t incY, const float16_t c, const float16_t s, const uint_fast8_t rndMode); void qrotg(float128_t *a, float128_t *b, float128_t *c, float128_t *s, const uint_fast8_t rndMode); diff --git a/src/blas/level1/hsdot.c b/src/blas/level1/hsdot.c index ac9ed3e..3fcfe48 100644 --- a/src/blas/level1/hsdot.c +++ b/src/blas/level1/hsdot.c @@ -1,9 +1,17 @@ #include "softblas.h" -float32_t hsdot(const uint16_t N, const float16_t alpha, const float16_t *X, const uint16_t incX, const float16_t *Y, const uint16_t incY) { - float32_t dot = alpha; - for (uint64_t i = N; i; i--, X += incX, Y += incY) { - dot = f32_add(dot, f32_mul(f16_to_f32(*X), f16_to_f32(*Y))); +// hsdot: dot product of half-precision vectors accumulated in single +// precision, plus the single-precision bias (alpha widened from half). +float32_t hsdot(const uint64_t N, const float16_t alpha, const float16_t *X, const int64_t incX, const float16_t *Y, const int64_t incY, const uint_fast8_t rndMode) { + _set_rounding(rndMode); + float32_t dot = f16_to_f32(alpha); + int64_t ix = 0, iy = 0; + if (incX < 0) ix = (-N + 1) * incX; + if (incY < 0) iy = (-N + 1) * incY; + for (uint64_t i = 0; i < N; i++) { + dot = f32_add(dot, f32_mul(f16_to_f32(X[ix]), f16_to_f32(Y[iy]))); + ix += incX; + iy += incY; } - return(dot); + return nan_unify_s(dot); } diff --git a/src/blas/level1/sdsdot.c b/src/blas/level1/sdsdot.c index 0356042..b42ae95 100644 --- a/src/blas/level1/sdsdot.c +++ b/src/blas/level1/sdsdot.c @@ -1,17 +1,17 @@ #include "softblas.h" -float32_t sdsdot(const uint64_t N, const float32_t alpha, const float32_t *SX, const int64_t incX, const float32_t *SY, const int64_t incY) { +// sdsdot: dot product of single-precision vectors accumulated in double +// precision, plus the single-precision bias alpha. Returns single precision. +float32_t sdsdot(const uint64_t N, const float32_t alpha, const float32_t *SX, const int64_t incX, const float32_t *SY, const int64_t incY, const uint_fast8_t rndMode) { + _set_rounding(rndMode); float64_t dsdot = f32_to_f64(alpha); - - int64_t ix = 0; - int64_t iy = 0; - if (incX < 0) ix = (-n + 1) * incX; - if (incY < 0) iy = (-n + 1) * incY; - for (uint64_t i = 0; i < n; i++) { + int64_t ix = 0, iy = 0; + if (incX < 0) ix = (-N + 1) * incX; + if (incY < 0) iy = (-N + 1) * incY; + for (uint64_t i = 0; i < N; i++) { dsdot = f64_add(dsdot, f64_mul(f32_to_f64(SX[ix]), f32_to_f64(SY[iy]))); ix += incX; iy += incY; } - - return(f64_to_f32(dot)); + return nan_unify_s(f64_to_f32(dsdot)); } diff --git a/tests/blas/include/test.h b/tests/blas/include/test.h index f2f428d..feae3a7 100644 --- a/tests/blas/include/test.h +++ b/tests/blas/include/test.h @@ -471,5 +471,8 @@ MunitResult test_srotm_basic(const MunitParameter params[], void* u); MunitResult test_srotmg_basic(const MunitParameter params[], void* u); MunitResult test_srotmg_flag_neg2(const MunitParameter params[], void* u); MunitResult test_drotmg_basic(const MunitParameter params[], void* u); +// Extended-precision dot (test_sdsdot.c) +MunitResult test_sdsdot_basic(const MunitParameter params[], void* u); +MunitResult test_hsdot_basic(const MunitParameter params[], void* u); #endif // TEST_H diff --git a/tests/blas/level1/test_sdsdot.c b/tests/blas/level1/test_sdsdot.c new file mode 100644 index 0000000..024a343 --- /dev/null +++ b/tests/blas/level1/test_sdsdot.c @@ -0,0 +1,22 @@ +#include "test.h" + +// sdsdot: alpha + sum(x*y) in double precision. 1 + 2*4 + 3*5 = 24. +MunitResult test_sdsdot_basic(const MunitParameter params[], void* u) { + const float32_t alpha = { SB_REAL32_ONE }; + float32_t* SX = svec((float[]){2.0f, 3.0f}, 2); + float32_t* SY = svec((float[]){4.0f, 5.0f}, 2); + float32_t r = sdsdot(2, alpha, SX, 1, SY, 1, 'n'); + assert_ulong(r.v, ==, 0x41c00000u); // 24 + free(SX); free(SY); + return MUNIT_OK; +} +// hsdot: half inputs, single accumulation. 1 + 2*4 + 3*5 = 24. +MunitResult test_hsdot_basic(const MunitParameter params[], void* u) { + const float16_t alpha = { SB_REAL16_ONE }; + float16_t* X = hvec((uint16_t[]){0x4000, 0x4200}, 2); // 2, 3 + float16_t* Y = hvec((uint16_t[]){0x4400, 0x4500}, 2); // 4, 5 + float32_t r = hsdot(2, alpha, X, 1, Y, 1, 'n'); + assert_ulong(r.v, ==, 0x41c00000u); // 24 (single) + free(X); free(Y); + return MUNIT_OK; +} diff --git a/tests/test_all.c b/tests/test_all.c index a1f1296..9859246 100644 --- a/tests/test_all.c +++ b/tests/test_all.c @@ -36,6 +36,9 @@ int main(int argc, char* argv[MUNIT_ARRAY_PARAM(argc + 1)]) { {"/test_sdot_12345", test_sdot_12345, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/test_sdot_stride", test_sdot_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/test_sdot_neg_stride", test_sdot_neg_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + + {"/test_sdsdot_basic", test_sdsdot_basic, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, + {"/test_hsdot_basic", test_hsdot_basic, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/test_snrm2_0", test_snrm2_0, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/test_snrm2_12345", test_snrm2_12345, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL}, {"/test_snrm2_stride", test_snrm2_stride, NULL, NULL, MUNIT_TEST_OPTION_NONE, NULL},