|
15 | 15 |
|
16 | 16 | import torch
|
17 | 17 | from botorch.acquisition.acquisition import AcquisitionFunction
|
18 |
| -from botorch.acquisition.analytic import AnalyticAcquisitionFunction |
19 | 18 | from botorch.acquisition.objective import GenericMCObjective
|
20 |
| -from botorch.exceptions import UnsupportedError |
| 19 | +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper |
21 | 20 | from torch import Tensor
|
22 | 21 |
|
23 | 22 |
|
@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
|
139 | 138 | return regularization_term
|
140 | 139 |
|
141 | 140 |
|
142 |
| -class PenalizedAcquisitionFunction(AcquisitionFunction): |
| 141 | +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): |
143 | 142 | r"""Single-outcome acquisition function regularized by the given penalty.
|
144 | 143 |
|
145 | 144 | The usage is similar to:
|
@@ -161,29 +160,16 @@ def __init__(
|
161 | 160 | penalty_func: The regularization function.
|
162 | 161 | regularization_parameter: Regularization parameter used in optimization.
|
163 | 162 | """
|
164 |
| -super().__init__(model=raw_acqf.model) |
165 |
| -self.raw_acqf = raw_acqf |
| 163 | +AcquisitionFunction.__init__(self, model=raw_acqf.model) |
| 164 | +AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) |
166 | 165 | self.penalty_func = penalty_func
|
167 | 166 | self.regularization_parameter = regularization_parameter
|
168 | 167 |
|
169 | 168 | def forward(self, X: Tensor) -> Tensor:
|
170 |
| -raw_value = self.raw_acqf(X=X) |
| 169 | +raw_value = self.acq_func(X=X) |
171 | 170 | penalty_term = self.penalty_func(X)
|
172 | 171 | return raw_value - self.regularization_parameter * penalty_term
|
173 | 172 |
|
174 |
| -@property |
175 |
| -def X_pending(self) -> Optional[Tensor]: |
176 |
| -return self.raw_acqf.X_pending |
177 |
| - |
178 |
| -def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: |
179 |
| -if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): |
180 |
| -self.raw_acqf.set_X_pending(X_pending=X_pending) |
181 |
| -else: |
182 |
| -raise UnsupportedError( |
183 |
| -"The raw acquisition function is Analytic and does not account " |
184 |
| -"for X_pending yet." |
185 |
| -) |
186 |
| - |
187 | 173 |
|
188 | 174 | def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
|
189 | 175 | r"""Computes the group lasso regularization function for the given point.
|
|
0 commit comments