(旧存档) 牛客多校第一场-D题
D 凸优化
题意
给出正定的实对称矩阵$\boldsymbol{A} (\boldsymbol{A}\in \mathbb{R}^{n\times n})$和向量$\boldsymbol{b}(\boldsymbol{b} \in \mathbb{R}^{n\times 1})$,向量$\boldsymbol{x} (\boldsymbol{x} \in \mathbb{R}^{n\times 1})$满足约束:$\boldsymbol{x}^T \boldsymbol{A} \boldsymbol{x} \le 1$,试求$\max \limits_{\boldsymbol {x}} (\boldsymbol {b}^T\boldsymbol {x})^2$
思路
构造拉格朗日函数,$L(\boldsymbol{x}, y) = \boldsymbol{b}^T\boldsymbol{x} - y(\boldsymbol{x}^T \boldsymbol{A} \boldsymbol{x} - 1)$,其中 $y \ge 0$。当 $\boldsymbol{x}$ 不满足约束时,$\min \limits_{y \ge 0} L(\boldsymbol{x}, y) = -\infty$;满足约束时,该最小值等于 $\boldsymbol{b}^T\boldsymbol{x}$,那么
$\max \limits_{\boldsymbol{x}} \boldsymbol{b}^T\boldsymbol{x} = \max \limits_{\boldsymbol{x}} \min \limits_{y \ge 0}L(\boldsymbol{x}, y)$
首先对$y$求偏导,得:
$\frac{\partial L}{\partial y} = -(\boldsymbol{x}^T \boldsymbol{A} \boldsymbol{x} - 1) = 0$
得到取最大值时,有:$\boldsymbol{x}^T \boldsymbol{A} \boldsymbol{x} = 1$
由拉格朗日对偶性,可得: $\max \limits_{\boldsymbol{x}} \boldsymbol{b}^T\boldsymbol{x} = \min \limits_{y \ge 0}\max \limits_{\boldsymbol{x}} L(\boldsymbol{x}, y)$
再对$\boldsymbol x$求偏导,得:
$\frac{\partial L}{\partial x} = \boldsymbol{b} - 2y\boldsymbol A\boldsymbol x = 0$
得到:$\boldsymbol x = \frac{1}{2y}\boldsymbol A^{-1}\boldsymbol b$,联立 $\boldsymbol{x}^T \boldsymbol{A} \boldsymbol{x} = 1$,可得,取最大值时:$\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b}=4y^2$
同样地,取最大值时:$\boldsymbol {b}^T\boldsymbol {x} = \frac{1}{2y}\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b}$,故:
$\max \limits_{\boldsymbol {x}} (\boldsymbol {b}^T\boldsymbol {x})^2 = (\frac{1}{2y}\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b})\cdot(\frac{1}{2y}\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b})=\frac{1}{4y^2}(\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b})(\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b})$ $= \boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b}$
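在模 $998244353$ 意义下用高斯消元求出 $\boldsymbol{A}^{-1}$,再计算 $\boldsymbol{b}^{T}\boldsymbol{A}^{-1}\boldsymbol{b}$ 即可,时间复杂度:$O(n^3)$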
代码
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
// N sizes the static buffers below (rows of `a`, and the invA/b/temp arrays).
// MOD is the modulus for all arithmetic; divm() inverts via b^(MOD-2), which
// relies on MOD being prime.
const int N = 205, MOD = 998244353;
// 64-bit integer used so that products of two residues below MOD do not overflow.
using ll = long long;
// Element type of the elimination matrix.
using result = ll;
// Augmented matrix: main() stores A in columns [0, n) and an identity matrix in
// columns [n, 2n); guass() reduces it in place.
result a[N][2 * N];
/**
 * Modular exponentiation by repeated squaring: computes a^p modulo MOD.
 *
 * The base is reduced into [0, MOD) before the loop, so negative bases give a
 * canonical non-negative residue and large bases cannot overflow when squared.
 *
 * @param a base (any long long value)
 * @param p exponent; any value <= 0 yields 1 (the empty product)
 * @return a^p mod MOD, always in [0, MOD)
 */
ll quick_pow(ll a, ll p) {
    a %= MOD;
    if (a < 0) a += MOD;  // C++ % keeps the dividend's sign

    ll ret = 1;
    // `p > 0` rather than `p`: a negative exponent would never reach zero
    // under an arithmetic right shift.
    while (p > 0) {
        if (p & 1) ret = ret * a % MOD;
        a = a * a % MOD;
        p >>= 1;
    }
    return ret;
}
/**
 * Modular division: multiplies `a` by b^(MOD-2), which is the inverse of `b`
 * modulo MOD when MOD is prime and `b` is not a multiple of MOD.
 *
 * No error is reported when `b` has no inverse. The sign of the result follows
 * C++ `%` applied to the product.
 *
 * @param a numerator
 * @param b denominator
 * @return (a * b^(MOD-2)) % MOD
 */
ll divm(ll a, ll b) {
    const ll inverse = quick_pow(b, MOD - 2);
    return a * inverse % MOD;
}
int guass(int n) {
for (int i = 0; i < n; ++i) {
result pivot = 0;
int u = 0;
for (int j = i; j < n; ++j) {
if (std::fabs(pivot) < std::fabs(a[j][i])) {
u = j;
pivot = a[j][i];
}
}
if (pivot == 0) return false;
std::swap(a[i], a[u]);
for (int j = 0; j < n; ++j) {
if (i != j && a[j][i]) {
result ratio = divm(a[j][i], a[i][i]);
for (int k = i; k <= 2 * n; ++k) {
a[j][k] = (a[j][k] - ratio * a[i][k]) % MOD;
a[j][k] += (a[j][k] < 0) * MOD;
}
}
}
}
return true;
}
// Inverse matrix modulo MOD; main() fills it from the right half of `a` after
// elimination, and calc() reads it.
ll invA[N][N];
// Input vector read by main() and consumed by calc().
ll b[N];
// Scratch vector written by calc(): holds the product invA * b.
ll temp[N];
/**
 * Evaluates the quadratic form b^T * invA * b modulo MOD using the globals
 * `invA` and `b`; the intermediate vector invA * b is stored in `temp`.
 *
 * Operands are reduced modulo MOD before each multiplication and every sum is
 * brought back into [0, MOD). C++ `%` keeps the dividend's sign, so without
 * this a negative entry in `b` could make the returned value negative.
 *
 * @param n dimension of the vector and matrix
 * @return the quadratic form as a residue in [0, MOD)
 */
ll calc(int n) {
    // temp = invA * b
    for (int i = 0; i < n; ++i) {
        ll acc = 0;
        for (int k = 0; k < n; ++k) {
            acc = (acc + (b[k] % MOD) * (invA[i][k] % MOD)) % MOD;
        }
        if (acc < 0) acc += MOD;
        temp[i] = acc;
    }

    // ret = b . temp
    ll ret = 0;
    for (int i = 0; i < n; ++i) {
        ret = (ret + temp[i] * (b[i] % MOD)) % MOD;
    }
    if (ret < 0) ret += MOD;
    return ret;
}
int main() {
int n;
while(~scanf("%d", &n)){
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
scanf("%lld", &a[i][j]);
}
}
for (int i = 0; i < n; ++i) {
for (int j = n; j <= 2 * n; ++j) {
a[i][j] = (j == i + n);
}
}
for (int i = 0; i < n; ++i) {
scanf("%lld", &b[i]);
}
guass(n);
for (int i = 0; i < n; ++i) {
for (int j = n; j < 2 * n; ++j) {
invA[i][j - n] = divm(a[i][j], a[i][i]);
}
}
printf("%lld\n", calc(n));
}
return 0;
}