@@ -49,12 +49,13 @@ def llm_remote_text_df(session, llm_remote_text_pandas_df):
|
49 | 49 | return session.read_pandas(llm_remote_text_pandas_df)
|
50 | 50 |
|
51 | 51 |
|
| 52 | +@pytest.mark.flaky(retries=2) |
52 | 53 | def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df):
|
53 | 54 | model = bigframes.ml.llm.PaLM2TextGenerator(
|
54 | 55 | model_name="text-bison", max_iterations=1
|
55 | 56 | )
|
56 | 57 |
|
57 |
| -df = llm_fine_tune_df_default_index.dropna() |
| 58 | +df = llm_fine_tune_df_default_index.dropna().sample(n=100) |
58 | 59 | X_train = df[["prompt"]]
|
59 | 60 | y_train = df[["label"]]
|
60 | 61 | model.fit(X_train, y_train)
|
@@ -70,6 +71,7 @@ def test_llm_palm_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_
|
70 | 71 | # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept
|
71 | 72 |
|
72 | 73 |
|
| 74 | +@pytest.mark.flaky(retries=2) |
73 | 75 | def test_llm_palm_score(llm_fine_tune_df_default_index):
|
74 | 76 | model = bigframes.ml.llm.PaLM2TextGenerator(model_name="text-bison")
|
75 | 77 |
|
@@ -89,6 +91,7 @@ def test_llm_palm_score(llm_fine_tune_df_default_index):
|
89 | 91 | assert all(col in score_result_col for col in expected_col)
|
90 | 92 |
|
91 | 93 |
|
| 94 | +@pytest.mark.flaky(retries=2) |
92 | 95 | def test_llm_palm_score_params(llm_fine_tune_df_default_index):
|
93 | 96 | model = bigframes.ml.llm.PaLM2TextGenerator(
|
94 | 97 | model_name="text-bison", max_iterations=1
|
@@ -102,12 +105,10 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
|
102 | 105 | ).to_pandas()
|
103 | 106 | score_result_col = score_result.columns.to_list()
|
104 | 107 | expected_col = [
|
105 |
| -"trial_id", |
106 | 108 | "precision",
|
107 | 109 | "recall",
|
108 |
| -"accuracy", |
109 | 110 | "f1_score",
|
110 |
| -"log_loss", |
111 |
| -"roc_auc", |
| 111 | +"label", |
| 112 | +"evaluation_status", |
112 | 113 | ]
|
113 | 114 | assert all(col in score_result_col for col in expected_col)
|
0 commit comments