File tree
Expand file treeCollapse file tree1 file changed
+11
-3
lines changed Expand file treeCollapse file tree1 file changed
+11
-3
lines changed Original file line number | Diff line number | Diff line change |
---|
|
89 | 89 | # %%
|
90 | 90 | df_test = pd.read_csv(csv_test)
|
91 | 91 |
|
92 |
| -predictions = model.predict(csv_test) |
93 |
| -print(predictions[0]) |
| 92 | +dm = TabularClassificationData.from_data_frame( |
| 93 | +predict_data_frame=df_test, |
| 94 | +parameters=datamodule.parameters, |
| 95 | +batch_size=datamodule.batch_size, |
| 96 | +) |
| 97 | +preds = trainer.predict(model, datamodule=dm, output="classes") |
| 98 | +print(preds[0][:10]) |
94 | 99 |
|
95 | 100 | # %%
|
| 101 | +import itertools # noqa: E402] |
| 102 | + |
96 | 103 | import numpy as np # noqa: E402]
|
97 | 104 |
|
98 |
| -assert len(df_test) == len(predictions) |
| 105 | +predictions = list(itertools.chain(*preds)) |
| 106 | +# assert len(df_test) == len(predictions) |
99 | 107 |
|
100 | 108 | df_test["Survived"] = np.argmax(predictions, axis=-1)
|
101 | 109 | df_test.set_index("PassengerId", inplace=True)
|
|
You can’t perform that action at this time.
0 commit comments