Replaces the BigQuery wrap-view pattern with a discovery + scoped-fetch toolkit driven by the analyst's Claude session. Adds /api/v2/{catalog,schema,sample,scan,scan/estimate}, da catalog/schema/describe/fetch/snapshot/disk-info CLI commands, sqlglot-backed WHERE validator, process-local quota tracker, agent rails skill (cli/skills/agnes-data-querying.md). BREAKING: BQ wrap views off by default — set data_source.bigquery.legacy_wrap_views=true for one cycle. Backward-compat field_validator on primary_key. Catalog cache now matches documented 300s TTL with RBAC fresh per request. Cuts release v0.14.0.
244 lines
9.7 KiB
Python
244 lines
9.7 KiB
Python
"""WHERE clause validator for /api/v2/scan.

Single security perimeter — every analyst-supplied predicate flows through here
before reaching BigQuery. Allow-list-driven; explicit rejection codes per spec §3.7.
"""
|
|
|
|
from __future__ import annotations
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Mapping
|
|
|
|
import sqlglot
|
|
from sqlglot import exp
|
|
from sqlglot.errors import ParseError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Rejection kind codes. These strings are a stable contract: callers, tests and
# the audit log match on them, so never rename an existing value.

# Empty predicate, sqlglot parse failure, or no WHERE node in the parsed input.
REJECT_PARSE = "parse_error"
# A subquery / nested SELECT appears inside the predicate.
REJECT_NESTED_SELECT = "nested_select"
# The input did not parse to exactly one statement.
REJECT_MULTI_STATEMENT = "multi_statement"
# A DDL/DML node (INSERT, UPDATE, DROP, ...) appears inside the predicate.
REJECT_DDL_DML = "ddl_or_dml"
# A column is qualified with a table other than the target table.
REJECT_CROSS_TABLE = "cross_table_reference"
# A function outside the allow-list (aggregates included) is called.
REJECT_UNKNOWN_FUNCTION = "unknown_function"
# A column that is not in the target table's schema is referenced.
REJECT_UNKNOWN_COLUMN = "unknown_column"
# The parsed top-level statement is not a plain SELECT (e.g. a UNION).
REJECT_DISALLOWED_NODE = "disallowed_node"
|
|
|
|
|
|
@dataclass
|
|
class WhereValidationError(Exception):
|
|
kind: str
|
|
message: str
|
|
detail: dict | None = None
|
|
|
|
def __str__(self) -> str:
|
|
return f"[{self.kind}] {self.message}"
|
|
|
|
|
|
# sqlglot node types that mean "this is DDL or DML". Any occurrence inside a
# predicate is rejected outright.
_DDL_DML_NODES = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Drop,
    exp.TruncateTable,
    exp.Alter,
    exp.Create,
    exp.Copy,
    exp.Merge,
)
|
|
|
|
|
|
# v1 BigQuery function allow-list (spec §3.7). Stored as upper-case names.
# Typed sqlglot nodes are matched by their canonical SQL name, which can differ
# from the BigQuery spelling the analyst typed; where that happens both
# spellings are listed.
_ALLOW_FUNCTIONS_DATETIME = {
    "CURRENT_DATE", "CURRENT_TIMESTAMP", "CURRENT_TIME",
    "DATE", "DATETIME", "TIMESTAMP", "TIME",
    "DATE_ADD", "DATE_SUB", "DATE_DIFF", "DATE_TRUNC", "EXTRACT",
    "FORMAT_DATE", "FORMAT_TIMESTAMP", "PARSE_DATE", "PARSE_TIMESTAMP",
    "UNIX_SECONDS", "UNIX_MILLIS",
}
_ALLOW_FUNCTIONS_STRING = {
    "CONCAT", "LENGTH", "LOWER", "UPPER", "SUBSTR", "SUBSTRING",
    "TRIM", "LTRIM", "RTRIM", "REPLACE",
    "STARTS_WITH", "ENDS_WITH", "CONTAINS_SUBSTR",
    "REGEXP_CONTAINS", "REGEXP_EXTRACT", "SAFE_CAST",
    # sqlglot normalizes some BQ funcs to a canonical SQL name; allow both spellings.
    "REGEXP_LIKE",  # sqlglot canonical for REGEXP_CONTAINS
}
_ALLOW_FUNCTIONS_MATH = {
    "ABS", "CEIL", "FLOOR", "ROUND", "MOD", "POWER", "SQRT",
    "LOG", "LN", "EXP", "SIGN", "GREATEST", "LEAST",
}
_ALLOW_FUNCTIONS_CAST = {
    "CAST",
    # sqlglot parses SAFE_CAST into its TryCast node, whose canonical name is
    # TRY_CAST; without this entry the allow-listed SAFE_CAST never matches.
    "TRY_CAST",
}
_ALLOW_FUNCTIONS_CONDITIONAL = {"IF", "IFNULL", "COALESCE", "NULLIF", "CASE"}

# Union of every category above; the single set the function walker consults.
ALLOWED_FUNCTIONS: frozenset[str] = frozenset(
    _ALLOW_FUNCTIONS_DATETIME
    | _ALLOW_FUNCTIONS_STRING
    | _ALLOW_FUNCTIONS_MATH
    | _ALLOW_FUNCTIONS_CAST
    | _ALLOW_FUNCTIONS_CONDITIONAL
)
|
|
|
|
|
|
def validate_where(
    predicate: str,
    table_id: str,
    schema: Mapping[str, str],
    *,
    dialect: str = "bigquery",
) -> exp.Expression:
    """Validate a WHERE-clause fragment.

    The fragment is wrapped in ``SELECT 1 FROM t WHERE <predicate>``, parsed
    with sqlglot, and the resulting WHERE subtree is checked structurally.

    Args:
        predicate: SQL fragment (without leading 'WHERE').
        table_id: target table id; cross-table references rejected.
        schema: {column_name: type} for the target table.
        dialect: sqlglot dialect to parse with. Default 'bigquery'. Pass 'duckdb'
            (or anything sqlglot supports) when the predicate will be executed
            against a local DuckDB scan, so DuckDB-specific syntax parses.

    Returns:
        Parsed sqlglot expression tree (caller may re-stringify or inspect).
        This is the WHERE node only.
        NOTE(review): clauses a predicate appends after the condition
        (ORDER BY, LIMIT, ...) would not be part of this node and are not
        inspected here — callers must build SQL from the returned node, never
        from the raw string.

    Raises:
        WhereValidationError: with .kind set to one of the REJECT_* codes.
    """
    if not predicate or not predicate.strip():
        raise WhereValidationError(REJECT_PARSE, "empty predicate")

    # Multi-statement detection: BQ statements separated by ';' would parse
    # as multiple expressions in sqlglot.parse() (returns a list).
    try:
        statements = sqlglot.parse(f"SELECT 1 FROM t WHERE {predicate}", dialect=dialect)
    except (ParseError, sqlglot.errors.TokenError) as e:
        # TokenError (e.g. an unterminated string literal) is raised by the
        # tokenizer and is NOT a ParseError subclass; without catching it the
        # failure would escape as an unhandled exception instead of a clean
        # parse_error rejection.
        raise WhereValidationError(REJECT_PARSE, f"parse failed: {e}") from e

    if statements is None or len(statements) != 1 or statements[0] is None:
        raise WhereValidationError(REJECT_MULTI_STATEMENT, "multi-statement input not allowed")

    select = statements[0]
    # A predicate like `1=1 UNION ALL SELECT secret FROM x` parses as a single
    # `exp.Union` (not `exp.Select`), and `find(exp.Where)` would return only
    # the left side's `1=1` — passing structural checks while the raw predicate
    # string still gets concatenated into the final SQL. Reject here.
    if not isinstance(select, exp.Select):
        raise WhereValidationError(
            REJECT_DISALLOWED_NODE,
            f"top-level statement must be SELECT, got {type(select).__name__}",
        )
    where = select.find(exp.Where)
    if where is None:
        raise WhereValidationError(REJECT_PARSE, "no WHERE expression found in parsed input")

    _walk_structural(where, table_id, schema)
    return where
|
|
|
|
|
|
def safe_where_predicate(
    predicate: str,
    table_id: str,
    schema: Mapping[str, str],
    *,
    dialect: str = "bigquery",
) -> str:
    """Validate `predicate` and hand back SQL regenerated from its parsed tree.

    This is the function to call wherever a predicate ends up concatenated
    into final SQL. The raw input must never be spliced in directly: sqlglot
    discards ``-- ...`` line comments and ``/* ... */`` block comments while
    parsing, so something like ``1=1 --`` validates clean (the AST holds just
    ``1=1``) yet the raw text would comment out whatever follows it in the
    final statement — LIMIT, ORDER BY and the server-enforced row caps.

    Args:
        predicate: SQL fragment (without leading 'WHERE').
        table_id: target table id, forwarded to ``validate_where``.
        schema: {column_name: type} for the target table.
        dialect: sqlglot dialect used for both parsing and re-rendering.

    Returns:
        The condition re-rendered from the AST, comments stripped, without the
        WHERE keyword.

    Raises:
        WhereValidationError: propagated from ``validate_where``.
    """
    validated = validate_where(predicate, table_id, schema, dialect=dialect)
    # `.this` on the WHERE node is the bare condition (no WHERE keyword).
    condition = validated.this
    # comments=False drops `/* ... */` comments that sqlglot keeps as node
    # metadata; `--` line comments never survive parsing in the first place.
    return condition.sql(dialect=dialect, comments=False)
|
|
|
|
|
|
def _walk_structural(node: exp.Expression, table_id: str, schema: Mapping[str, str]) -> None:
    """Run every structural check against the WHERE subtree rooted at `node`.

    Checks run in a fixed order and the first violation wins:
    nested SELECT / DDL-DML nodes, then cross-table column qualifiers, then
    the function allow-list, then schema membership of columns.

    Raises:
        WhereValidationError: with the REJECT_* code of the first failed check.
    """
    # Pass 1: forbidden node types anywhere in the tree (the walk includes
    # `node` itself, hence the identity guard on the SELECT test).
    for visited in node.walk():
        nested_select = isinstance(visited, exp.Select) and visited is not node
        if nested_select or isinstance(visited, exp.Subquery):
            raise WhereValidationError(REJECT_NESTED_SELECT, "nested SELECT/subquery not allowed")
        if isinstance(visited, _DDL_DML_NODES):
            raise WhereValidationError(
                REJECT_DDL_DML, f"DDL/DML node {type(visited).__name__} not allowed"
            )

    # Pass 2: a column may be unqualified or qualified with the target table
    # (case-insensitive); any other qualifier is a cross-table reference.
    for column in node.find_all(exp.Column):
        qualifier = column.table  # e.g. "other_table" in `other_table.id`
        if not qualifier or qualifier.lower() == table_id.lower():
            continue
        raise WhereValidationError(
            REJECT_CROSS_TABLE,
            f"column {column.sql()} references table {qualifier!r}, expected {table_id!r}",
        )

    # Pass 3 and 4: delegate to the dedicated walkers.
    _walk_functions(node)
    _walk_columns(node, schema)
|
|
|
|
|
|
def _walk_columns(node: exp.Expression, schema: Mapping[str, str]) -> None:
    """Ensure every column referenced under `node` exists in `schema`.

    Names are compared case-insensitively against the schema's keys.

    Raises:
        WhereValidationError: kind ``REJECT_UNKNOWN_COLUMN`` for the first
            column whose name is not a schema key; ``detail`` carries the name.
    """
    schema_names = {name.lower() for name in schema}
    for column in node.find_all(exp.Column):
        # `column.name` is the leaf column name (e.g. "country_code" in
        # "tbl.country_code"). For dotted struct fields like "rec.sub.leaf",
        # sqlglot models as nested exp.Dot; v1 only checks top-level names.
        lowered = (column.name or "").lower()
        if not lowered or lowered in schema_names:
            continue
        raise WhereValidationError(
            REJECT_UNKNOWN_COLUMN,
            f"column {column.name!r} not in schema for {column.table!r}",
            detail={"column": column.name},
        )
|
|
|
|
|
|
def _walk_functions(node: exp.Expression) -> None:
    """Reject function calls outside the allow-list.

    sqlglot represents function calls in two ways:
    - typed subclasses (e.g. ``exp.Length``, ``exp.StartsWith``, ``exp.Cast``,
      ``exp.Coalesce``) — their SQL spellings come from ``sql_names()``, with
      the canonical one first;
    - ``exp.Anonymous`` for unknown built-ins or UDFs — name in ``func.name``.
    Both paths funnel into ``ALLOWED_FUNCTIONS``; everything else is rejected.

    A typed node is accepted when ANY of its spellings is allow-listed. Checking
    only the canonical spelling wrongly rejects allow-listed functions whose
    node lists another spelling first (e.g. ``exp.DateDiff`` is ``DATEDIFF``
    before ``DATE_DIFF``).

    Raises:
        WhereValidationError: kind ``REJECT_UNKNOWN_FUNCTION`` for aggregates,
            for function-like nodes whose name cannot be determined, and for
            functions not in the allow-list. ``detail["function"]`` names the
            offender.
    """
    for func in node.find_all(exp.Func):
        # Logical connectors (AND/OR/XOR) inherit exp.Func in sqlglot but are
        # operators, not user-callable functions. Skip them.
        if isinstance(func, exp.Connector):
            continue

        if isinstance(func, exp.AggFunc):
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"aggregate function not allowed in WHERE: {type(func).__name__}",
                detail={"function": type(func).__name__.upper()},
            )

        # `Anonymous` carries the source name in `func.name`; typed nodes carry
        # their SQL spellings via `sql_names()`. `name` on typed nodes often
        # holds the first child's identifier, so we never trust it directly.
        if isinstance(func, exp.Anonymous):
            spellings = [(func.name or "").upper()]
        else:
            try:
                spellings = [(s or "").upper() for s in func.sql_names()]
            except Exception:
                # Deliberately broad: whatever goes wrong while asking the node
                # for its names, fail closed via the empty-name rejection below.
                spellings = []

        # The first spelling is the canonical one; it is what gets reported.
        name = spellings[0] if spellings else ""

        # If no name is available for a typed Func, we can't tell whether
        # it's a benign operator wrapper or a future dangerous construct.
        # Reject (defense in depth) — if a legitimate case appears, add the
        # specific subclass to the explicit-skip list above (Connector, etc.).
        if not name:
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"unrecognized function-like node: {type(func).__name__}",
                detail={"function": type(func).__name__},
            )

        if ALLOWED_FUNCTIONS.isdisjoint(spellings):
            raise WhereValidationError(
                REJECT_UNKNOWN_FUNCTION,
                f"function not in v1 allow-list: {name}",
                detail={"function": name},
            )
|