File tree

4 files changed

+124
-50
lines changed

4 files changed

+124
-50
lines changed
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def predict(
244244
245245
Args:
246246
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
247-
Input DataFrame or Series, which contains only one column of prompts.
247+
Input DataFrame or Series, which can contain one or more columns. If multiple columns are present in the DataFrame, it must contain a "prompt" column for prediction.
248248
Prompts can include preamble, questions, suggestions, instructions, or examples.
249249
250250
temperature (float, default 0.0):
@@ -307,14 +307,10 @@ def predict(
307307

308308
(X,) = utils.convert_to_dataframe(X)
309309

310-
if len(X.columns) != 1:
311-
raise ValueError(
312-
f"Only support one column as input. {constants.FEEDBACK_LINK}"
313-
)
314-
315-
# BQML identified the column by name
316-
col_label = cast(blocks.Label, X.columns[0])
317-
X = X.rename(columns={col_label: "prompt"})
310+
if len(X.columns) == 1:
311+
# BQML identified the column by name
312+
col_label = cast(blocks.Label, X.columns[0])
313+
X = X.rename(columns={col_label: "prompt"})
318314

319315
options = {
320316
"temperature": temperature,
@@ -522,7 +518,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
522518
523519
Args:
524520
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
525-
Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples.
521+
Input DataFrame or Series, which can contain one or more columns. If multiple columns are present in the DataFrame, it must contain a "content" column for prediction.
526522
527523
Returns:
528524
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -531,14 +527,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
531527
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
532528
(X,) = utils.convert_to_dataframe(X)
533529

534-
if len(X.columns) != 1:
535-
raise ValueError(
536-
f"Only support one column as input. {constants.FEEDBACK_LINK}"
537-
)
538-
539-
# BQML identified the column by name
540-
col_label = cast(blocks.Label, X.columns[0])
541-
X = X.rename(columns={col_label: "content"})
530+
if len(X.columns) == 1:
531+
# BQML identified the column by name
532+
col_label = cast(blocks.Label, X.columns[0])
533+
X = X.rename(columns={col_label: "content"})
542534

543535
options = {
544536
"flatten_json_output": True,
@@ -679,7 +671,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
679671
680672
Args:
681673
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
682-
Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples.
674+
Input DataFrame or Series, which can contain one or more columns. If multiple columns are present in the DataFrame, it must contain a "content" column for prediction.
683675
684676
Returns:
685677
bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.
@@ -688,14 +680,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
688680
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
689681
(X,) = utils.convert_to_dataframe(X)
690682

691-
if len(X.columns) != 1:
692-
raise ValueError(
693-
f"Only support one column as input. {constants.FEEDBACK_LINK}"
694-
)
695-
696-
# BQML identified the column by name
697-
col_label = cast(blocks.Label, X.columns[0])
698-
X = X.rename(columns={col_label: "content"})
683+
if len(X.columns) == 1:
684+
# BQML identified the column by name
685+
col_label = cast(blocks.Label, X.columns[0])
686+
X = X.rename(columns={col_label: "content"})
699687

700688
options = {
701689
"flatten_json_output": True,
@@ -893,7 +881,7 @@ def predict(
893881
894882
Args:
895883
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
896-
Input DataFrame or Series, which contains only one column of prompts.
884+
Input DataFrame or Series, which can contain one or more columns. If multiple columns are present in the DataFrame, it must contain a "prompt" column for prediction.
897885
Prompts can include preamble, questions, suggestions, instructions, or examples.
898886
899887
temperature (float, default 0.9):
@@ -938,14 +926,10 @@ def predict(
938926

939927
(X,) = utils.convert_to_dataframe(X)
940928

941-
if len(X.columns) != 1:
942-
raise ValueError(
943-
f"Only support one column as input. {constants.FEEDBACK_LINK}"
944-
)
945-
946-
# BQML identified the column by name
947-
col_label = cast(blocks.Label, X.columns[0])
948-
X = X.rename(columns={col_label: "prompt"})
929+
if len(X.columns) == 1:
930+
# BQML identified the column by name
931+
col_label = cast(blocks.Label, X.columns[0])
932+
X = X.rename(columns={col_label: "prompt"})
949933

950934
options = {
951935
"temperature": temperature,
@@ -1181,7 +1165,7 @@ def predict(
11811165
11821166
Args:
11831167
X (bigframes.dataframe.DataFrame or bigframes.series.Series):
1184-
Input DataFrame or Series, which contains only one column of prompts.
1168+
Input DataFrame or Series, which can contain one or more columns. If multiple columns are present in the DataFrame, it must contain a "prompt" column for prediction.
11851169
Prompts can include preamble, questions, suggestions, instructions, or examples.
11861170
11871171
max_output_tokens (int, default 128):
@@ -1222,14 +1206,10 @@ def predict(
12221206

12231207
(X,) = utils.convert_to_dataframe(X)
12241208

1225-
if len(X.columns) != 1:
1226-
raise ValueError(
1227-
f"Only support one column as input. {constants.FEEDBACK_LINK}"
1228-
)
1229-
1230-
# BQML identified the column by name
1231-
col_label = cast(blocks.Label, X.columns[0])
1232-
X = X.rename(columns={col_label: "prompt"})
1209+
if len(X.columns) == 1:
1210+
# BQML identified the column by name
1211+
col_label = cast(blocks.Label, X.columns[0])
1212+
X = X.rename(columns={col_label: "prompt"})
12331213

12341214
options = {
12351215
"max_output_tokens": max_output_tokens,
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,27 @@ def test_claude3_text_generator_predict_with_params_success(
156156
utils.check_pandas_df_schema_and_index(
157157
df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False
158158
)
159+
160+
161+
@pytest.mark.parametrize(
162+
"model_name",
163+
("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"),
164+
)
165+
@pytest.mark.flaky(retries=3, delay=120)
166+
def test_claude3_text_generator_predict_multi_col_success(
167+
llm_text_df, model_name, session, session_us_east5, bq_connection
168+
):
169+
if model_name in ("claude-3-5-sonnet", "claude-3-opus"):
170+
session = session_us_east5
171+
172+
llm_text_df["additional_col"] = 1
173+
claude3_text_generator_model = llm.Claude3TextGenerator(
174+
model_name=model_name, connection_name=bq_connection, session=session
175+
)
176+
df = claude3_text_generator_model.predict(llm_text_df).to_pandas()
177+
utils.check_pandas_df_schema_and_index(
178+
df,
179+
columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
180+
index=3,
181+
col_exact=False,
182+
)
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
from bigframes.ml import llm
18+
import bigframes.pandas as bpd
1819
from tests.system import utils
1920

2021

@@ -166,6 +167,20 @@ def test_text_generator_predict_arbitrary_col_label_success(
166167
)
167168

168169

170+
@pytest.mark.flaky(retries=2)
171+
def test_text_generator_predict_multiple_cols_success(
172+
palm2_text_generator_model, llm_text_df: bpd.DataFrame
173+
):
174+
df = llm_text_df.assign(additional_col=1)
175+
pd_df = palm2_text_generator_model.predict(df).to_pandas()
176+
utils.check_pandas_df_schema_and_index(
177+
pd_df,
178+
columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
179+
index=3,
180+
col_exact=False,
181+
)
182+
183+
169184
@pytest.mark.flaky(retries=2)
170185
def test_text_generator_predict_with_params_success(
171186
palm2_text_generator_model, llm_text_df
@@ -212,11 +227,33 @@ def test_text_embedding_generator_predict_default_params_success(
212227
model_name=model_name, connection_name=bq_connection, session=session
213228
)
214229
df = text_embedding_model.predict(llm_text_df).to_pandas()
215-
assert df.shape == (3, 4)
216-
assert "ml_generate_embedding_result" in df.columns
217-
series = df["ml_generate_embedding_result"]
218-
value = series[0]
219-
assert len(value) == 768
230+
utils.check_pandas_df_schema_and_index(
231+
df, columns=utils.ML_GENERATE_EMBEDDING_OUTPUT, index=3, col_exact=False
232+
)
233+
assert len(df["ml_generate_embedding_result"][0]) == 768
234+
235+
236+
@pytest.mark.parametrize(
237+
"model_name",
238+
("text-embedding-004", "text-multilingual-embedding-002"),
239+
)
240+
@pytest.mark.flaky(retries=2)
241+
def test_text_embedding_generator_multi_cols_predict_success(
242+
llm_text_df: bpd.DataFrame, model_name, session, bq_connection
243+
):
244+
df = llm_text_df.assign(additional_col=1)
245+
df = df.rename(columns={"prompt": "content"})
246+
text_embedding_model = llm.TextEmbeddingGenerator(
247+
model_name=model_name, connection_name=bq_connection, session=session
248+
)
249+
pd_df = text_embedding_model.predict(df).to_pandas()
250+
utils.check_pandas_df_schema_and_index(
251+
pd_df,
252+
columns=utils.ML_GENERATE_EMBEDDING_OUTPUT + ["additional_col"],
253+
index=3,
254+
col_exact=False,
255+
)
256+
assert len(pd_df["ml_generate_embedding_result"][0]) == 768
220257

221258

222259
@pytest.mark.parametrize(
@@ -295,6 +332,33 @@ def test_gemini_text_generator_predict_with_params_success(
295332
)
296333

297334

335+
@pytest.mark.parametrize(
336+
"model_name",
337+
(
338+
"gemini-pro",
339+
"gemini-1.5-pro-preview-0514",
340+
"gemini-1.5-flash-preview-0514",
341+
"gemini-1.5-pro-001",
342+
"gemini-1.5-flash-001",
343+
),
344+
)
345+
@pytest.mark.flaky(retries=2)
346+
def test_gemini_text_generator_multi_cols_predict_success(
347+
llm_text_df: bpd.DataFrame, model_name, session, bq_connection
348+
):
349+
df = llm_text_df.assign(additional_col=1)
350+
gemini_text_generator_model = llm.GeminiTextGenerator(
351+
model_name=model_name, connection_name=bq_connection, session=session
352+
)
353+
pd_df = gemini_text_generator_model.predict(df).to_pandas()
354+
utils.check_pandas_df_schema_and_index(
355+
pd_df,
356+
columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"],
357+
index=3,
358+
col_exact=False,
359+
)
360+
361+
298362
@pytest.mark.flaky(retries=2)
299363
def test_llm_palm_score(llm_fine_tune_df_default_index):
300364
model = llm.PaLM2TextGenerator(model_name="text-bison")
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
"ml_generate_text_status",
5151
"prompt",
5252
]
53+
ML_GENERATE_EMBEDDING_OUTPUT = [
54+
"ml_generate_embedding_result",
55+
"ml_generate_embedding_statistics",
56+
"ml_generate_embedding_status",
57+
"content",
58+
]
5359

5460

5561
def skip_legacy_pandas(test):

0 commit comments

Comments
 (0)