File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df):
4949
return session.read_pandas(llm_remote_text_pandas_df)
5050

5151

52+
@pytest.mark.flaky(retries=2)
5253
def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df):
5354
model = bigframes.ml.llm.PaLM2TextGenerator(
5455
model_name="text-bison", max_iterations=1
5556
)
5657

57-
df = llm_fine_tune_df_default_index.dropna()
58+
df = llm_fine_tune_df_default_index.dropna().sample(n=100)
5859
X_train = df[["prompt"]]
5960
y_train = df[["label"]]
6061
model.fit(X_train, y_train)
@@ -70,6 +71,7 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_
7071
# TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
7172

7273

74+
@pytest.mark.flaky(retries=2)
7375
def test_llm_palm_score(llm_fine_tune_df_default_index):
7476
model = bigframes.ml.llm.PaLM2TextGenerator(model_name="text-bison")
7577

@@ -89,6 +91,7 @@ def test_llm_palm_score(llm_fine_tune_df_default_index):
8991
assert all(col in score_result_col for col in expected_col)
9092

9193

94+
@pytest.mark.flaky(retries=2)
9295
def test_llm_palm_score_params(llm_fine_tune_df_default_index):
9396
model = bigframes.ml.llm.PaLM2TextGenerator(
9497
model_name="text-bison", max_iterations=1
@@ -102,12 +105,10 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
102105
).to_pandas()
103106
score_result_col = score_result.columns.to_list()
104107
expected_col = [
105-
"trial_id",
106108
"precision",
107109
"recall",
108-
"accuracy",
109110
"f1_score",
110-
"log_loss",
111-
"roc_auc",
111+
"label",
112+
"evaluation_status",
112113
]
113114
assert all(col in score_result_col for col in expected_col)

0 commit comments

Comments
 (0)