构建支持在线推理与离线分析的混合特征存储架构


机器学习系统在生产环境中面临一个根本性的矛盾:模型训练与离线分析需要对海量历史数据进行灵活、复杂的批处理查询,而在线推理服务则要求对单点数据进行毫秒级的低延迟查找。试图用单一存储系统满足这两种截然不同的负载模式,通常会导致架构上的妥协和性能瓶颈。传统的做法是维护两套独立的系统——一个用于在线服务的键值存储(如Redis),一个用于离线分析的数据仓库(如Hive),并通过复杂的ETL管道同步数据,这带来了数据一致性延迟、存储成本翻倍和运维复杂度剧增的问题。

本文将探讨一种混合架构,旨在统一在线与离线场景,同时满足两种负载的需求。该架构的核心挑战在于如何构建一个数据平面,使其既能支持AI模型(PyTorch)进行亚十毫秒级的特征获取,又能为数据科学家提供一个强大的SQL接口(Trino),对同一份数据进行探索性分析和特征工程。

架构决策:两种失败的备选方案

在确定最终架构前,我们评估并否决了两种看似可行但存在严重缺陷的方案。

方案A:万能数据库方案 (PostgreSQL / MySQL)

一种直接的想法是使用功能强大的关系型数据库来同时处理在线和离线请求。

  • 优势:
    • 技术栈统一,易于维护。
    • 强大的SQL支持和事务能力,数据一致性有保障。
  • 劣势:
    • 水平扩展性不足: 面对每日数亿甚至数十亿的特征写入请求,传统关系型数据库的分片和扩展能力非常有限,很快会成为写入瓶颈。
    • 在线读取延迟: 虽然对主键的查询性能尚可,但在高并发下,B-Tree索引的锁竞争和磁盘IO开销难以稳定控制在个位数毫秒。
    • 分析查询性能差: 对于需要全表扫描或复杂聚合的分析型查询,行式存储的关系型数据库性能极差,会严重影响在线服务的稳定性。

在真实项目中,这种方案仅适用于早期原型或数据量极小的场景。一旦流量增长,它将是整个系统中最先崩溃的一环。

方案B:彻底的读写分离与数据复制 (ETL)

这是业界较为常见的解决方案:使用一套系统处理在线流量,另一套处理离线分析,中间通过ETL同步。

graph TD
    subgraph Online Path
        IngestService --> Redis(Redis for Serving);
        PyTorchService -- Read Feature --> Redis;
    end

    subgraph Offline Path
        ETL_Job[Spark/Flink ETL] -- Batch Copy --> DataWarehouse(Hive/S3);
        DataScientist -- SQL --> DataWarehouse;
    end

    IngestService -- Raw Events --> Kafka;
    Kafka --> ETL_Job;
  • 优势:
    • 在线和离线负载物理隔离,互不影响。
    • 每个组件都用于其最擅长的领域(Redis负责低延迟,Hive负责吞吐量)。
  • 劣势:
    • 数据一致性延迟: ETL过程通常是小时级甚至天级,导致在线系统看到的数据与离线系统分析的数据存在巨大鸿沟。这对于需要近实时反馈的欺诈检测或推荐系统是致命的。
    • 存储与计算成本翻倍: 同一份数据至少存储两份。ETL过程本身也消耗大量计算资源。
    • 运维复杂性: 维护Kafka、ETL作业、在线存储、离线仓库这一整条链路,需要专门的数据工程团队,排查问题非常困难。一个常见的错误是,修复了在线数据的问题,却忘记在ETL逻辑中同步修复,导致数据不一致性随时间累积。

最终架构:基于Cassandra和Trino的混合数据平面

我们的目标是找到一个方案,既能避免方案A的性能瓶颈,又能解决方案B的数据割裂问题。最终选定的架构利用了Cassandra的分布式特性和Trino的联邦查询能力。

graph TD
    A[API Gateway] --> B(Feature Ingestion Service);
    A --> C(PyTorch Inference Service);
    A --> D(Trino Coordinator);

    B -- Write Features --> E[(Cassandra Cluster)];
    C -- Direct Read (Low Latency) --> E;

    subgraph Offline Analysis
        D -- Federated Query --> E;
        D -- Federated Query --> F[S3 Data Lake];
        G[Data Scientist/Analyst] -- SQL Client --> D;
    end

    subgraph Online Serving
        H[End User] -- Request --> A;
        C -- Loads Model --> I((PyTorch Model));
    end

这个架构的核心思想是:

  1. 统一的写入路径: 所有特征数据,无论是实时事件流还是批量计算结果,都通过API网关写入一个高可用的分布式数据库——Cassandra。
  2. 双重读取路径:
    • 在线路径: PyTorch推理服务为了追求极致的低延迟,通过原生的Cassandra驱动直接连接数据库,执行基于主键的点查操作。
    • 离线路径: 数据科学家通过Trino连接器访问Cassandra。Trino将Cassandra表暴露为标准SQL表,允许进行复杂的聚合、过滤和连接操作。更关键的是,Trino可以同时连接其他数据源(如存储历史归档的S3),实现跨数据源的联邦查询。

这样,我们用同一份数据(至少是热数据)服务于两种场景,从根本上解决了数据延迟和不一致问题。

技术选型理由

  • Cassandra:

    • 写入性能: 其LSM-Tree存储引擎对写入操作极其友好,能够轻松支撑高吞吐量的数据摄入。
    • 水平扩展与高可用: 无主(Masterless)架构,增加节点即可线性扩展性能和容量。数据通过副本策略分布在多个节点,任何单点故障都不会影响服务。
    • 低延迟点查: 只要查询严格基于分区键,Cassandra的读取延迟可以稳定在个位数毫秒内,完美满足在线推理的需求。
  • Trino (Presto):

    • 联邦查询引擎: Trino本身不存储数据,而是通过其丰富的连接器(Connector)查询底层系统。这是连接在线存储(Cassandra)和离线存储(S3)的桥梁。
    • 强大的SQL兼容性: 为NoSQL数据库Cassandra提供了完整的ANSI SQL能力,极大地降低了数据分析的门槛。
    • 内存计算: Trino是纯内存计算引擎,对于中等规模的交互式查询响应速度远超MapReduce模型。
  • API Gateway:

    • 统一入口: 为特征写入、模型推理、离线查询提供统一的认证、鉴权、限流和路由。
    • 解耦: 后端服务可以独立迭代和部署,网关负责暴露稳定的API。
  • PyTorch Service:

    • 生产级推理: 使用TorchScript将模型序列化,可以在不依赖Python解释器全局锁(GIL)的环境中高效运行。

核心实现细节

1. Cassandra数据建模

在Cassandra中,数据模型的设计直接决定了查询性能。一切设计都必须围绕查询模式展开。假设我们要存储用户画像特征,在线查询总是通过user_id进行。

-- DDL for Cassandra
CREATE KEYSPACE feature_store WITH replication = {
  'class': 'NetworkTopologyStrategy',
  'datacenter1': '3'
} AND durable_writes = true;

USE feature_store;

CREATE TABLE user_features (
    user_id text,
    feature_name text,
    feature_value double,
    last_updated timestamp,
    PRIMARY KEY (user_id, feature_name)
) WITH CLUSTERING ORDER BY (feature_name ASC);
  • 分区键 (user_id): 这是最重要的设计。所有关于一个用户的数据都会被存储在同一个物理分区(由同一组节点副本管理)。在线推理服务通过user_id查询时,Cassandra可以直接定位到数据所在节点,避免了全集群扫描。
  • 聚类键 (feature_name): 在一个分区内,数据会按照feature_name排序。这使得我们可以高效地获取一个用户的所有特征或特定范围的特征。
  • 反范式设计: 注意到这是一个“窄表”设计。在真实项目中,为了减少一次查询的请求次数,我们通常会采用“宽表”的反范式设计,将一个用户的所有特征存储在一行中。
-- Alternative Wide-Table Schema
CREATE TABLE user_features_wide (
    user_id text PRIMARY KEY,
    feature_a double,
    feature_b double,
    feature_c_vector list<float>,
    last_updated map<text, timestamp>
);

这里的权衡是:宽表减少了读取次数,但更新单个特征需要读取整行再写回(Read-Modify-Write),增加了写入的复杂度。选择哪种模型取决于读写比例和业务需求。

2. Trino配置与联邦查询

要让Trino能够查询Cassandra,只需配置一个简单的连接器文件。

trino/etc/catalog/cassandra.properties

# Connector name, will be used as the catalog name in SQL queries
connector.name=cassandra

# Comma-separated list of contact points for the Cassandra cluster
cassandra.contact-points=cassandra-node1.example.com,cassandra-node2.example.com

# Native protocol port
cassandra.native-protocol-port=9042

# Allow Trino to infer schema from Cassandra tables
cassandra.allow-drop-table=false

# Consistency level for Trino queries, ANALYTICS workloads can tolerate lower consistency
cassandra.consistency-level=LOCAL_QUORUM

# Credentials if required
# cassandra.username=...
# cassandra.password=...

配置完成后,数据科学家就可以像查询普通数据库一样查询Cassandra:

-- A typical analytical query via Trino
SELECT
    feature_name,
    AVG(feature_value) AS avg_value,
    COUNT(DISTINCT user_id) AS distinct_users
FROM
    cassandra.feature_store.user_features
WHERE
    last_updated > NOW() - INTERVAL '7' DAY
GROUP BY
    feature_name
ORDER BY
    distinct_users DESC
LIMIT 100;

这个查询在Cassandra原生CQL中是无法或极难高效实现的,但Trino的计算引擎可以轻松处理。

3. PyTorch在线推理服务

这是对延迟最敏感的部分。服务必须使用高效的异步框架和原生驱动来与Cassandra通信。

inference_server.py

import asyncio
import logging
import os
from contextlib import asynccontextmanager

import torch
import uvicorn
from cassandra.cluster import Cluster, Session
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import SimpleStatement
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# --- Configuration ---
# In a real app, use Pydantic's BaseSettings to load from env vars
CASSANDRA_HOSTS = os.getenv("CASSANDRA_HOSTS", "127.0.0.1").split(",")
CASSANDRA_PORT = int(os.getenv("CASSANDRA_PORT", 9042))
CASSANDRA_KEYSPACE = "feature_store"
MODEL_PATH = "model.pt"

# --- Cassandra Connection Management ---
# A global session object managed by the application lifecycle
cassandra_session: Session | None = None

def get_cassandra_session():
    """Provides a singleton Cassandra session."""
    global cassandra_session
    if cassandra_session is None or cassandra_session.is_shutdown:
        # Production-grade load balancing policy
        load_balancing_policy = TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc="datacenter1"))
        cluster = Cluster(
            CASSANDRA_HOSTS,
            port=CASSANDRA_PORT,
            load_balancing_policy=load_balancing_policy,
            protocol_version=4
        )
        cassandra_session = cluster.connect(CASSANDRA_KEYSPACE)
    return cassandra_session

def shutdown_cassandra_session():
    """Gracefully shuts down the Cassandra session and cluster connection."""
    global cassandra_session
    if cassandra_session and not cassandra_session.is_shutdown:
        logging.info("Shutting down Cassandra session...")
        cassandra_session.cluster.shutdown()
        cassandra_session = None

# --- Application Lifecycle (Lifespan for FastAPI) ---
ml_models = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logging.info("Application starting up...")
    get_cassandra_session() # Initialize connection pool on startup
    logging.info("Cassandra session initialized.")
    # Load the TorchScript model
    # Using torch.jit.load for production environments is a best practice
    ml_models["my_model"] = torch.jit.load(MODEL_PATH)
    ml_models["my_model"].eval() # Set to evaluation mode
    logging.info(f"Model '{MODEL_PATH}' loaded.")
    yield
    # Shutdown
    logging.info("Application shutting down...")
    shutdown_cassandra_session()
    ml_models.clear()
    logging.info("Resources cleaned up.")

app = FastAPI(lifespan=lifespan)

# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- API Models ---
class InferenceRequest(BaseModel):
    user_id: str = Field(..., description="The user ID to fetch features for.")
    # Other context features could be passed here
    # request_context: dict

class InferenceResponse(BaseModel):
    prediction: float
    model_version: str = "v1.0.0"

# --- Core Inference Logic ---
@app.post("/inference", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    """
    Fetches user features from Cassandra, runs model inference, and returns the prediction.
    """
    session = get_cassandra_session()
    
    # Using async execute for non-blocking IO is crucial for performance
    # This requires a library like 'cassandra-driver-async' or running in a thread pool
    # For simplicity, we use the sync driver's execute_async which returns a future
    # and we 'await' it in a thread-safe way using asyncio.to_thread in a real async app.
    # Here, we'll demonstrate the synchronous path for clarity. A fully async
    # implementation would use an async-native driver.
    
    query = SimpleStatement("SELECT feature_name, feature_value FROM user_features WHERE user_id = %s")
    
    try:
        # In a real high-performance app, you would use execute_async
        # and manage the futures to run queries concurrently.
        rows = session.execute(query, (request.user_id,))
        
        features = {row.feature_name: row.feature_value for row in rows}
        if not features:
            raise HTTPException(status_code=404, detail=f"User '{request.user_id}' not found.")

        # --- Feature Transformation & Tensor Creation ---
        # The order of features must match the model's training order.
        # A feature registry or at least a fixed list is essential here.
        feature_order = ["feature_a", "feature_b", "feature_c"] # Example order
        
        feature_vector = [features.get(name, 0.0) for name in feature_order] # Use default for missing features
        
        input_tensor = torch.tensor([feature_vector], dtype=torch.float32)

        # --- Model Inference ---
        with torch.no_grad(): # Disable gradient calculation for efficiency
            prediction_tensor = ml_models["my_model"](input_tensor)
        
        prediction_value = prediction_tensor.item()
        
        return InferenceResponse(prediction=prediction_value)

    except Exception as e:
        logging.error(f"Error during inference for user {request.user_id}: {e}")
        # A common mistake is to expose internal errors. Return a generic message.
        raise HTTPException(status_code=500, detail="Internal server error during inference.")

if __name__ == "__main__":
    # Unit testing considerations:
    # 1. Mock the get_cassandra_session() to return a mock session object.
    # 2. Test the feature transformation logic independently.
    # 3. Test for edge cases: missing user, missing features, malformed data.
    uvicorn.run(app, host="0.0.0.0", port=8000)

这段代码包含了生产级服务需要考虑的几个关键点:

  • 连接池管理: get_cassandra_session 创建了一个单例的会话对象,在整个应用生命周期内复用,避免了为每个请求创建连接的巨大开销。
  • 生命周期管理: 使用FastAPI的lifespan事件,在服务启动时预加载模型和初始化数据库连接,在服务关闭时优雅地释放资源。
  • 错误处理: 对数据库查询失败和用户不存在的情况进行了处理,并返回了合适的HTTP状态码。
  • 性能考量: 使用torch.jit.load加载序列化的模型,并在torch.no_grad()上下文中执行推理,这些都是提升性能的标准实践。

4. API Gateway路由配置

最后,API Gateway(以Kong为例)的配置将所有流量串联起来。

kong.yaml (declarative configuration)

_format_version: "2.1"

services:
- name: ingestion-service
  url: http://feature-ingestion-service.internal:8080
- name: inference-service
  url: http://pytorch-inference-service.internal:8000
- name: trino-service
  url: http://trino-coordinator.internal:8080

routes:
- name: ingestion-route
  service: ingestion-service
  paths:
  - /v1/features/ingest
  strip_path: true
  methods: [POST]

- name: inference-route
  service: inference-service
  paths:
  - /v1/models/predict
  strip_path: true
  # Example of a plugin: rate limiting for inference API
  plugins:
  - name: rate-limiting
    config:
      minute: 1000
      policy: local

- name: trino-route
  service: trino-service
  paths:
  - /v1/sql
  strip_path: true
  # This route should be protected by strong authentication (e.g., JWT, mTLS)
  # to prevent unauthorized access to the entire data platform.

这部分配置将外部请求路径映射到内部服务,并可以附加认证、限流、日志等横切关注点,而无需修改后端服务代码。

架构的局限性与未来展望

尽管此架构解决了核心矛盾,但它并非没有缺点。运维一个包含Cassandra、Trino、API Gateway和多个微服务的系统,其复杂性远高于单一数据库方案。Cassandra的调优(如Compaction策略、GC调优)本身就是一个专门的领域。

其次,该方案主要解决了热数据的统一访问问题。对于数年之久的历史冷数据,持续存储在Cassandra中成本高昂。一个自然的演进方向是引入数据分层策略:通过一个定时任务,将超过特定时间(例如3个月)的数据从Cassandra迁移到成本更低的S3对象存储中,并以Parquet等列式格式存储。由于Trino可以同时查询Cassandra和S3,这种迁移对数据分析师是透明的,他们依然可以通过同一个SQL接口查询全量数据。

此外,随着特征数量和复杂度的增加,一个手动的特征管理流程会变得混乱。下一步可以引入一个中心化的特征注册表(Feature Registry),如Feast,来管理特征的定义、元数据和生命周期,确保训练和服务之间使用特征的一致性。


  目录