|
88 | 88 | "name": "publishers/google/models/chat-bison",
|
89 | 89 | "version_id": "001",
|
90 | 90 | "open_source_category": "PROPRIETARY",
|
91 |
| -"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.PUBLIC_PREVIEW, |
| 91 | +"launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA, |
92 | 92 | "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/chat-bison@001",
|
93 | 93 | "predict_schemata": {
|
94 | 94 | "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml",
|
@@ -792,6 +792,139 @@ def test_chat(self):
|
792 | 792 | gca_predict_response2 = gca_prediction_service.PredictResponse()
|
793 | 793 | gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2)
|
794 | 794 |
|
| 795 | +with mock.patch.object( |
| 796 | +target=prediction_service_client.PredictionServiceClient, |
| 797 | +attribute="predict", |
| 798 | +return_value=gca_predict_response2, |
| 799 | +): |
| 800 | +message_text2 = "When were these books published?" |
| 801 | +expected_response2 = _TEST_CHAT_GENERATION_PREDICTION2["candidates"][0][ |
| 802 | +"content" |
| 803 | +] |
| 804 | +response = chat.send_message(message_text2, temperature=0.1) |
| 805 | +assert response.text == expected_response2 |
| 806 | +assert len(chat.message_history) == 6 |
| 807 | +assert chat.message_history[4].author == chat.USER_AUTHOR |
| 808 | +assert chat.message_history[4].content == message_text2 |
| 809 | +assert chat.message_history[5].author == chat.MODEL_AUTHOR |
| 810 | +assert chat.message_history[5].content == expected_response2 |
| 811 | + |
| 812 | +# Validating the parameters |
| 813 | +chat_temperature = 0.1 |
| 814 | +chat_max_output_tokens = 100 |
| 815 | +chat_top_k = 1 |
| 816 | +chat_top_p = 0.1 |
| 817 | +message_temperature = 0.2 |
| 818 | +message_max_output_tokens = 200 |
| 819 | +message_top_k = 2 |
| 820 | +message_top_p = 0.2 |
| 821 | + |
| 822 | +chat2 = model.start_chat( |
| 823 | +temperature=chat_temperature, |
| 824 | +max_output_tokens=chat_max_output_tokens, |
| 825 | +top_k=chat_top_k, |
| 826 | +top_p=chat_top_p, |
| 827 | +) |
| 828 | + |
| 829 | +gca_predict_response3 = gca_prediction_service.PredictResponse() |
| 830 | +gca_predict_response3.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1) |
| 831 | + |
| 832 | +with mock.patch.object( |
| 833 | +target=prediction_service_client.PredictionServiceClient, |
| 834 | +attribute="predict", |
| 835 | +return_value=gca_predict_response3, |
| 836 | +) as mock_predict3: |
| 837 | +chat2.send_message("Are my favorite movies based on a book series?") |
| 838 | +prediction_parameters = mock_predict3.call_args[1]["parameters"] |
| 839 | +assert prediction_parameters["temperature"] == chat_temperature |
| 840 | +assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens |
| 841 | +assert prediction_parameters["topK"] == chat_top_k |
| 842 | +assert prediction_parameters["topP"] == chat_top_p |
| 843 | + |
| 844 | +chat2.send_message( |
| 845 | +"Are my favorite movies based on a book series?", |
| 846 | +temperature=message_temperature, |
| 847 | +max_output_tokens=message_max_output_tokens, |
| 848 | +top_k=message_top_k, |
| 849 | +top_p=message_top_p, |
| 850 | +) |
| 851 | +prediction_parameters = mock_predict3.call_args[1]["parameters"] |
| 852 | +assert prediction_parameters["temperature"] == message_temperature |
| 853 | +assert prediction_parameters["maxDecodeSteps"] == message_max_output_tokens |
| 854 | +assert prediction_parameters["topK"] == message_top_k |
| 855 | +assert prediction_parameters["topP"] == message_top_p |
| 856 | + |
| 857 | +def test_chat_ga(self): |
| 858 | +"""Tests the chat generation model.""" |
| 859 | +aiplatform.init( |
| 860 | +project=_TEST_PROJECT, |
| 861 | +location=_TEST_LOCATION, |
| 862 | +) |
| 863 | +with mock.patch.object( |
| 864 | +target=model_garden_service_client.ModelGardenServiceClient, |
| 865 | +attribute="get_publisher_model", |
| 866 | +return_value=gca_publisher_model.PublisherModel( |
| 867 | +_CHAT_BISON_PUBLISHER_MODEL_DICT |
| 868 | +), |
| 869 | +) as mock_get_publisher_model: |
| 870 | +model = language_models.ChatModel.from_pretrained("chat-bison@001") |
| 871 | + |
| 872 | +mock_get_publisher_model.assert_called_once_with( |
| 873 | +name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY |
| 874 | +) |
| 875 | + |
| 876 | +chat = model.start_chat( |
| 877 | +context=""" |
| 878 | +My name is Ned. |
| 879 | +You are my personal assistant. |
| 880 | +My favorite movies are Lord of the Rings and Hobbit. |
| 881 | +""", |
| 882 | +examples=[ |
| 883 | +language_models.InputOutputTextPair( |
| 884 | +input_text="Who do you work for?", |
| 885 | +output_text="I work for Ned.", |
| 886 | +), |
| 887 | +language_models.InputOutputTextPair( |
| 888 | +input_text="What do I like?", |
| 889 | +output_text="Ned likes watching movies.", |
| 890 | +), |
| 891 | +], |
| 892 | +message_history=[ |
| 893 | +language_models.ChatMessage( |
| 894 | +author=preview_language_models.ChatSession.USER_AUTHOR, |
| 895 | +content="Question 1?", |
| 896 | +), |
| 897 | +language_models.ChatMessage( |
| 898 | +author=preview_language_models.ChatSession.MODEL_AUTHOR, |
| 899 | +content="Answer 1.", |
| 900 | +), |
| 901 | +], |
| 902 | +temperature=0.0, |
| 903 | +) |
| 904 | + |
| 905 | +gca_predict_response1 = gca_prediction_service.PredictResponse() |
| 906 | +gca_predict_response1.predictions.append(_TEST_CHAT_GENERATION_PREDICTION1) |
| 907 | + |
| 908 | +with mock.patch.object( |
| 909 | +target=prediction_service_client.PredictionServiceClient, |
| 910 | +attribute="predict", |
| 911 | +return_value=gca_predict_response1, |
| 912 | +): |
| 913 | +message_text1 = "Are my favorite movies based on a book series?" |
| 914 | +expected_response1 = _TEST_CHAT_GENERATION_PREDICTION1["candidates"][0][ |
| 915 | +"content" |
| 916 | +] |
| 917 | +response = chat.send_message(message_text1) |
| 918 | +assert response.text == expected_response1 |
| 919 | +assert len(chat.message_history) == 4 |
| 920 | +assert chat.message_history[2].author == chat.USER_AUTHOR |
| 921 | +assert chat.message_history[2].content == message_text1 |
| 922 | +assert chat.message_history[3].author == chat.MODEL_AUTHOR |
| 923 | +assert chat.message_history[3].content == expected_response1 |
| 924 | + |
| 925 | +gca_predict_response2 = gca_prediction_service.PredictResponse() |
| 926 | +gca_predict_response2.predictions.append(_TEST_CHAT_GENERATION_PREDICTION2) |
| 927 | + |
795 | 928 | with mock.patch.object(
|
796 | 929 | target=prediction_service_client.PredictionServiceClient,
|
797 | 930 | attribute="predict",
|
|
0 commit comments