Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def test_from_optimization() -> None:
assert os.path.exists(csv_file)

assert bern_gqs.draws().shape == (1, 1, 10)
assert bern_gqs.draws(inc_sample=True).shape == (1, 1, 12)
assert bern_gqs.draws(inc_sample=True).shape == (1, 1, 13)

# draws_pd()
assert bern_gqs.draws_pd().shape == (1, 13)
Expand Down Expand Up @@ -611,18 +611,18 @@ def test_opt_save_iterations(caplog: pytest.LogCaptureFixture) -> None:
assert bern_gqs.draws(inc_warmup=True, inc_sample=True).shape == (
iters,
1,
12,
13,
)

assert bern_gqs.draws(concat_chains=True).shape == (1, 10)
assert bern_gqs.draws(concat_chains=True, inc_sample=True).shape == (1, 12)
assert bern_gqs.draws(concat_chains=True, inc_sample=True).shape == (1, 13)
assert bern_gqs.draws(concat_chains=True, inc_warmup=True).shape == (
iters,
10,
)
assert bern_gqs.draws(
concat_chains=True, inc_warmup=True, inc_sample=True
).shape == (iters, 12)
).shape == (iters, 13)

# stan_variable
theta = bern_gqs.stan_variable(var='theta')
Expand Down
26 changes: 13 additions & 13 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ def test_rosenbrock(caplog: pytest.LogCaptureFixture) -> None:
assert 'CmdStanMLE: model=rosenbrock' in repr(mle)
assert 'method=optimize' in repr(mle)
assert mle.converged
assert mle.column_names == ('lp__', 'x', 'y')
assert mle.column_names == ('lp__', 'converged__', 'x', 'y')
np.testing.assert_almost_equal(mle.stan_variable('x'), 1, decimal=3)
np.testing.assert_almost_equal(mle.stan_variable('y'), 1, decimal=3)
np.testing.assert_almost_equal(
mle.optimized_params_pd['x'][0], 1, decimal=3
)
np.testing.assert_almost_equal(mle.optimized_params_np[1], 1, decimal=3)
np.testing.assert_almost_equal(mle.optimized_params_np[2], 1, decimal=3)
np.testing.assert_almost_equal(mle.optimized_params_dict['x'], 1, decimal=3)
with caplog.at_level(logging.WARNING):
assert mle.optimized_iterations_np is None
Expand Down Expand Up @@ -131,8 +131,8 @@ def test_rosenbrock(caplog: pytest.LogCaptureFixture) -> None:
np.testing.assert_almost_equal(mle.stan_variable('x'), 1, decimal=3)
np.testing.assert_almost_equal(mle.stan_variable('y'), 1, decimal=3)

assert mle.optimized_params_np.shape == (3,)
np.testing.assert_almost_equal(mle.optimized_params_np[1], 1, decimal=3)
assert mle.optimized_params_np.shape == (4,)
np.testing.assert_almost_equal(mle.optimized_params_np[2], 1, decimal=3)
np.testing.assert_almost_equal(
mle.optimized_params_pd['x'][0], 1, decimal=3
)
Expand All @@ -143,7 +143,7 @@ def test_rosenbrock(caplog: pytest.LogCaptureFixture) -> None:
mle.optimized_iterations_np[0, 1]
!= mle.optimized_iterations_np[last_iter, 1]
)
for i in range(3):
for i in range(4):
assert (
mle.optimized_params_np[i]
== mle.optimized_iterations_np[last_iter, i]
Expand All @@ -162,7 +162,7 @@ def test_eight_schools(caplog: pytest.LogCaptureFixture) -> None:
assert 'method=optimize' in repr(mle)
assert not mle.converged
with caplog.at_level(logging.WARNING):
assert mle.optimized_params_pd.shape == (1, 11)
assert mle.optimized_params_pd.shape == (1, 12)
check_present(
caplog,
(
Expand Down Expand Up @@ -294,15 +294,15 @@ def test_optimize_good() -> None:
# test numpy output
assert isinstance(mle.optimized_params_np, np.ndarray)
np.testing.assert_almost_equal(mle.optimized_params_np[0], -5, decimal=2)
np.testing.assert_almost_equal(mle.optimized_params_np[1], 0.2, decimal=3)
np.testing.assert_almost_equal(mle.optimized_params_np[2], 0.2, decimal=3)

# test pandas output
assert mle.optimized_params_np[0] == mle.optimized_params_pd['lp__'][0]
assert mle.optimized_params_np[1] == mle.optimized_params_pd['theta'][0]
assert mle.optimized_params_np[2] == mle.optimized_params_pd['theta'][0]

# test dict output
assert mle.optimized_params_np[0] == mle.optimized_params_dict['lp__']
assert mle.optimized_params_np[1] == mle.optimized_params_dict['theta']
assert mle.optimized_params_np[2] == mle.optimized_params_dict['theta']


def test_negative_parameter_values() -> None:
Expand Down Expand Up @@ -530,14 +530,14 @@ def test_optimize_good_dict() -> None:
)
# test numpy output
np.testing.assert_almost_equal(mle.optimized_params_np[0], -5, decimal=2)
np.testing.assert_almost_equal(mle.optimized_params_np[1], 0.2, decimal=3)
np.testing.assert_almost_equal(mle.optimized_params_np[2], 0.2, decimal=3)


def test_optimize_rosenbrock() -> None:
stan = os.path.join(DATAFILES_PATH, 'optimize', 'rosenbrock.stan')
rose_model = CmdStanModel(stan_file=stan)
mle = rose_model.optimize(seed=1239812093, inits=None, algorithm='BFGS')
assert mle.column_names == ('lp__', 'x', 'y')
assert mle.column_names == ('lp__', 'converged__', 'x', 'y')
np.testing.assert_almost_equal(mle.optimized_params_dict['x'], 1, decimal=3)
np.testing.assert_almost_equal(mle.optimized_params_dict['y'], 1, decimal=3)

Expand All @@ -546,7 +546,7 @@ def test_optimize_no_data() -> None:
stan = os.path.join(DATAFILES_PATH, 'optimize', 'no_data.stan')
rose_model = CmdStanModel(stan_file=stan)
mle = rose_model.optimize(seed=1239812093)
assert mle.column_names == ('lp__', 'a')
assert mle.column_names == ('lp__', 'converged__', 'a')
np.testing.assert_almost_equal(mle.optimized_params_dict['a'], 0, decimal=3)


Expand Down Expand Up @@ -599,7 +599,7 @@ def test_exe_only() -> None:
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
mle = bern2_model.optimize(data=jdata)
assert mle.optimized_params_np[0] == mle.optimized_params_dict['lp__']
assert mle.optimized_params_np[1] == mle.optimized_params_dict['theta']
assert mle.optimized_params_np[2] == mle.optimized_params_dict['theta']


def test_complex_output() -> None:
Expand Down
Loading