agnes-the-ai-analyst/app/api/where_validator.py
ZdenekSrotyr 2e1dfb7553
feat(v2): claude-driven fetch primitives + 0.14.0 (#102)
Replaces the BigQuery wrap-view pattern with a discovery + scoped-fetch toolkit driven by the analyst's Claude session. Adds /api/v2/{catalog,schema,sample,scan,scan/estimate}, da catalog/schema/describe/fetch/snapshot/disk-info CLI commands, sqlglot-backed WHERE validator, process-local quota tracker, agent rails skill (cli/skills/agnes-data-querying.md). BREAKING: BQ wrap views off by default — set data_source.bigquery.legacy_wrap_views=true for one cycle. Backward-compat field_validator on primary_key. Catalog cache now matches documented 300s TTL with RBAC fresh per request. Cuts release v0.14.0.
2026-04-29 01:07:19 +02:00

244 lines
9.7 KiB
Python

"""WHERE clause validator for /api/v2/scan.
Single security perimeter — every analyst-supplied predicate flows through here
before reaching BigQuery. Allow-list-driven; explicit rejection codes per spec §3.7.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Mapping
import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError
logger = logging.getLogger(__name__)

# Rejection codes carried in WhereValidationError.kind. Callers, tests and the
# audit log match on these exact strings, so treat the values as stable.
REJECT_PARSE = "parse_error"  # empty input, unparseable input, or no WHERE clause
REJECT_NESTED_SELECT = "nested_select"  # SELECT or subquery inside the predicate
REJECT_MULTI_STATEMENT = "multi_statement"  # more than one statement parsed
REJECT_DDL_DML = "ddl_or_dml"  # DDL/DML node found in the predicate
REJECT_CROSS_TABLE = "cross_table_reference"  # column qualified with another table
REJECT_UNKNOWN_FUNCTION = "unknown_function"  # aggregate, unrecognized or non-allow-listed call
REJECT_UNKNOWN_COLUMN = "unknown_column"  # column absent from the target schema
REJECT_DISALLOWED_NODE = "disallowed_node"  # structurally disallowed input (e.g. UNION)
@dataclass
class WhereValidationError(Exception):
kind: str
message: str
detail: dict | None = None
def __str__(self) -> str:
return f"[{self.kind}] {self.message}"
# sqlglot node types for statements that define or modify data. None of these
# may appear anywhere inside a predicate; _walk_structural rejects them.
_DDL_DML_NODES = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Drop,
    exp.TruncateTable,
    exp.Alter,
    exp.Create,
    exp.Copy,
    exp.Merge,
)
# v1 BigQuery function allow-list (spec §3.7). Stored as upper-case names.
#
# _walk_functions matches typed sqlglot nodes by their canonical ``sql_name()``,
# which is not always the BigQuery spelling the analyst typed. Each group
# therefore also lists the canonical names sqlglot produces for allow-listed
# BigQuery functions; without them those functions would be rejected even
# though the spec allows them.
# NOTE(review): the alias mappings below reflect sqlglot's BigQuery dialect as
# observed; confirm against the pinned sqlglot version when upgrading.
_ALLOW_FUNCTIONS_DATETIME = {
    "CURRENT_DATE", "CURRENT_TIMESTAMP", "CURRENT_TIME",
    "DATE", "DATETIME", "TIMESTAMP", "TIME",
    "DATE_ADD", "DATE_SUB", "DATE_DIFF", "DATE_TRUNC", "EXTRACT",
    "FORMAT_DATE", "FORMAT_TIMESTAMP", "PARSE_DATE", "PARSE_TIMESTAMP",
    "UNIX_SECONDS", "UNIX_MILLIS",
    # sqlglot canonical names:
    #   FORMAT_DATE / FORMAT_TIMESTAMP -> TIME_TO_STR(TS_OR_DS_TO_DATE / _TIMESTAMP(...))
    #   PARSE_DATE -> STR_TO_DATE, PARSE_TIMESTAMP -> STR_TO_TIME
    #   DATE(year, month, day) -> DATE_FROM_PARTS
    "TIME_TO_STR", "TS_OR_DS_TO_DATE", "TS_OR_DS_TO_TIMESTAMP",
    "STR_TO_DATE", "STR_TO_TIME", "DATE_FROM_PARTS",
}
_ALLOW_FUNCTIONS_STRING = {
    "CONCAT", "LENGTH", "LOWER", "UPPER", "SUBSTR", "SUBSTRING",
    "TRIM", "LTRIM", "RTRIM", "REPLACE",
    "STARTS_WITH", "ENDS_WITH", "CONTAINS_SUBSTR",
    "REGEXP_CONTAINS", "REGEXP_EXTRACT", "SAFE_CAST",
    # sqlglot normalizes some BQ funcs to a canonical SQL name; allow both spellings.
    "REGEXP_LIKE",  # sqlglot canonical for REGEXP_CONTAINS
}
_ALLOW_FUNCTIONS_MATH = {
    "ABS", "CEIL", "FLOOR", "ROUND", "MOD", "POWER", "SQRT",
    "LOG", "LN", "EXP", "SIGN", "GREATEST", "LEAST",
}
# SAFE_CAST parses to sqlglot's TryCast node, whose sql_name() is TRY_CAST.
_ALLOW_FUNCTIONS_CAST = {"CAST", "TRY_CAST"}
_ALLOW_FUNCTIONS_CONDITIONAL = {"IF", "IFNULL", "COALESCE", "NULLIF", "CASE"}
ALLOWED_FUNCTIONS: frozenset[str] = frozenset(
    _ALLOW_FUNCTIONS_DATETIME
    | _ALLOW_FUNCTIONS_STRING
    | _ALLOW_FUNCTIONS_MATH
    | _ALLOW_FUNCTIONS_CAST
    | _ALLOW_FUNCTIONS_CONDITIONAL
)
def validate_where(
    predicate: str,
    table_id: str,
    schema: Mapping[str, str],
    *,
    dialect: str = "bigquery",
) -> exp.Expression:
    """Validate a WHERE-clause fragment.

    The fragment is parsed inside a throwaway ``SELECT 1 FROM t WHERE ...``
    wrapper, then checked structurally: no nested SELECT/subquery, no DDL/DML,
    no cross-table column references, allow-listed functions only, and only
    columns present in ``schema``.

    Args:
        predicate: SQL fragment (without leading 'WHERE').
        table_id: target table id; cross-table references rejected.
        schema: {column_name: type} for the target table.
        dialect: sqlglot dialect to parse with. Default 'bigquery'. Pass 'duckdb'
            (or anything sqlglot supports) when the predicate will be executed
            against a local DuckDB scan, so DuckDB-specific syntax parses.

    Returns:
        The parsed ``exp.Where`` node (caller may re-stringify or inspect).

    Raises:
        WhereValidationError: with .kind set to one of the REJECT_* codes.
    """
    if not predicate or not predicate.strip():
        raise WhereValidationError(REJECT_PARSE, "empty predicate")

    # Multi-statement detection: ';'-separated statements come back from
    # sqlglot.parse() as separate list entries.
    # Catch sqlglot's base error class rather than only ParseError: tokenizer
    # failures (e.g. an unterminated string literal) raise TokenError, a
    # sibling of ParseError, which would otherwise escape as an unhandled
    # exception instead of a parse_error rejection.
    try:
        statements = sqlglot.parse(f"SELECT 1 FROM t WHERE {predicate}", dialect=dialect)
    except sqlglot.errors.SqlglotError as e:
        raise WhereValidationError(REJECT_PARSE, f"parse failed: {e}") from e
    if statements is None or len(statements) != 1 or statements[0] is None:
        raise WhereValidationError(REJECT_MULTI_STATEMENT, "multi-statement input not allowed")
    select = statements[0]

    # A predicate like `1=1 UNION ALL SELECT secret FROM x` parses as a single
    # `exp.Union` (not `exp.Select`), and `find(exp.Where)` would return only
    # the left side's `1=1` — passing structural checks while the raw predicate
    # string still gets concatenated into the final SQL. Reject here.
    if not isinstance(select, exp.Select):
        raise WhereValidationError(
            REJECT_DISALLOWED_NODE,
            f"top-level statement must be SELECT, got {type(select).__name__}",
        )

    # Anything the fragment appends after the condition (ORDER BY, LIMIT,
    # GROUP BY/HAVING, QUALIFY, WINDOW, ...) attaches to the wrapper SELECT as
    # a query modifier, outside the WHERE subtree that _walk_structural walks.
    # Such clauses are never validated (e.g. `1=1 ORDER BY (SELECT ...)`) and
    # would be silently dropped or spliced raw, so reject them. Only the
    # wrapper's own parts may be populated ("from_" covers sqlglot versions
    # that use that key spelling).
    wrapper_args = {"expressions", "from", "from_", "where"}
    extra = sorted(
        key
        for key, value in select.args.items()
        if key not in wrapper_args and (isinstance(value, exp.Expression) or value)
    )
    if extra:
        raise WhereValidationError(
            REJECT_DISALLOWED_NODE,
            f"clauses after the predicate are not allowed: {', '.join(extra)}",
            detail={"clauses": extra},
        )

    where = select.args.get("where")
    if where is None:
        raise WhereValidationError(REJECT_PARSE, "no WHERE expression found in parsed input")
    _walk_structural(where, table_id, schema)
    return where
def safe_where_predicate(
    predicate: str,
    table_id: str,
    schema: Mapping[str, str],
    *,
    dialect: str = "bigquery",
) -> str:
    """Validate ``predicate`` and return a canonical, comment-free SQL fragment.

    This is the form to concatenate into final SQL. The raw input must never
    be spliced: comments in it (``-- ...`` or ``/* ... */``) are not part of
    the validated AST, so an input such as ``1=1 --`` validates as ``1=1``
    while its trailing comment would swallow the rest of the generated query
    (LIMIT, ORDER BY and the server-enforced row caps).

    Args:
        predicate: SQL fragment (without leading 'WHERE').
        table_id: target table id, forwarded to validate_where.
        schema: {column_name: type} for the target table.
        dialect: sqlglot dialect used both to parse and to re-render.

    Returns:
        The condition inside the validated WHERE clause, re-rendered by
        sqlglot in ``dialect`` with all comments omitted.

    Raises:
        WhereValidationError: propagated from validate_where.
    """
    validated = validate_where(predicate, table_id, schema, dialect=dialect)
    # `.this` is the bare condition, without the WHERE keyword.
    condition = validated.this
    # Rendering from the AST with comments=False also drops any `/* ... */`
    # comments sqlglot kept attached to nodes as metadata.
    return condition.sql(dialect=dialect, comments=False)
def _walk_structural(node: exp.Expression, table_id: str, schema: Mapping[str, str]) -> None:
    """Run every structural check on a parsed WHERE subtree.

    Checks, in order:
      1. no nested SELECT / subquery and no DDL/DML node anywhere in the tree;
      2. every qualified column is qualified with ``table_id`` (case-insensitive);
      3. function calls pass _walk_functions;
      4. column names pass _walk_columns against ``schema``.
    The first failing check raises WhereValidationError.
    """
    # walk() yields `node` itself first, hence the identity exclusion for Select.
    for current in node.walk():
        is_nested_query = isinstance(current, exp.Subquery) or (
            isinstance(current, exp.Select) and current is not node
        )
        if is_nested_query:
            raise WhereValidationError(REJECT_NESTED_SELECT, "nested SELECT/subquery not allowed")
        if isinstance(current, _DDL_DML_NODES):
            raise WhereValidationError(
                REJECT_DDL_DML, f"DDL/DML node {type(current).__name__} not allowed"
            )

    # Unqualified columns are fine; a qualifier must name the target table.
    for column in node.find_all(exp.Column):
        owner = column.table  # e.g. "other_table" in `other_table.id`
        if not owner or owner.lower() == table_id.lower():
            continue
        raise WhereValidationError(
            REJECT_CROSS_TABLE,
            f"column {column.sql()} references table {owner!r}, expected {table_id!r}",
        )

    _walk_functions(node)
    _walk_columns(node, schema)
def _walk_columns(node: exp.Expression, schema: Mapping[str, str]) -> None:
    """Raise if any column in ``node`` is missing from ``schema`` (case-insensitive).

    Only the leaf name of each column is compared (e.g. "country_code" in
    "tbl.country_code"); nested struct paths are not resolved in v1. Columns
    whose leaf name is empty are skipped.
    """
    schema_columns = {column_name.lower() for column_name in schema}
    for column in node.find_all(exp.Column):
        leaf_name = (column.name or "").lower()
        if not leaf_name or leaf_name in schema_columns:
            continue
        raise WhereValidationError(
            REJECT_UNKNOWN_COLUMN,
            f"column {column.name!r} not in schema for {column.table!r}",
            detail={"column": column.name},
        )
def _canonical_function_name(func: exp.Func) -> str:
    """Return the upper-cased name used to look ``func`` up in ALLOWED_FUNCTIONS.

    ``exp.Anonymous`` nodes (unknown built-ins, UDFs) keep the name as written
    in ``.name``. Typed nodes expose their canonical SQL spelling through
    ``sql_name()``; their ``.name`` often holds the first child's identifier,
    so it is never consulted. Any failure resolves to ``""``, which the caller
    treats as a rejection.
    """
    if isinstance(func, exp.Anonymous):
        return (func.name or "").upper()
    try:
        return (func.sql_name() or "").upper()
    except Exception:
        return ""


def _walk_functions(node: exp.Expression) -> None:
    """Raise unless every function call in ``node`` is on the allow-list.

    Aggregates are rejected outright. Every other call, whether a typed sqlglot
    node or ``exp.Anonymous``, must resolve to a non-empty name contained in
    ALLOWED_FUNCTIONS. Logical connectors (AND/OR/XOR) are skipped.
    """
    for func in node.find_all(exp.Func):
        # AND/OR/XOR subclass exp.Func in sqlglot but are operators, not calls.
        if isinstance(func, exp.Connector):
            continue

        node_type = type(func).__name__
        if isinstance(func, exp.AggFunc):
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"aggregate function not allowed in WHERE: {node_type}",
                detail={"function": node_type.upper()},
            )

        name = _canonical_function_name(func)
        # An unnamed function-like node is of unknown intent: fail closed. If a
        # legitimate node ever lands here, skip its class explicitly like the
        # Connector case above.
        if not name:
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"unrecognized function-like node: {node_type}",
                detail={"function": node_type},
            )
        if name not in ALLOWED_FUNCTIONS:
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"function not in v1 allow-list: {name}",
                detail={"function": name},
            )