import logging
import os
import sys
import pickle
import concurrent.futures
from threading import RLock
from timeit import default_timer as timer
from urllib.parse import urlparse

import joblib
import numpy as np
import pandas as pd
from flask import Flask, request, jsonify
from google.cloud import storage
from prometheus_client import Counter, generate_latest

from auth import user_token_required, admin_token_required
from features.user_history import FEATURE_COLUMNS
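

# Feature tables are refreshed while requests are being served, so reads and
# writes are serialized behind an RLock and get() hands out a shallow copy.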
class ThreadSafeUserFeatures:
    def __init__(self):
        self._data = {}
        self._lock = RLock()

    def get(self):
        with self._lock:
            return dict(self._data)

    def update(self, new_features):
        with self._lock:
            self._data.update(new_features)

app = Flask(__name__)
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
USER_HISTORY_PATH = 'user_features.pkl'
MODEL_PIPELINE_PATH = 'predict_read_model-v003.pkl'
pipeline = None
user_features_store = ThreadSafeUserFeatures()
# These buckets report the distribution of returned scores so we can verify
# there is decent diversity in the output.
score_bucket_ranges = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
score_buckets = {
    f'score_bucket_{int(b * 10)}': Counter(
        f'inference_score_bucket_{int(b * 10)}',
        f'Number of scores in the range {b - 0.1:.1f} to {b:.1f}',
    )
    for b in score_bucket_ranges
}
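

# Increment the counter for the (b - 0.1, b] bucket that contains the score.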
def observe_score(score):
    for b in score_bucket_ranges:
        if b - 0.1 < score <= b:
            score_buckets[f'score_bucket_{int(b * 10)}'].inc()
            break

def download_from_gcs(bucket_name, gcs_path, destination_path):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(gcs_path)
    blob.download_to_filename(destination_path)

def load_pipeline(path):
    return joblib.load(path)

def load_tables_from_pickle(path):
    with open(path, 'rb') as handle:
        return pickle.load(handle)

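
# Each pickled table exposes to_pandas() (e.g. a pyarrow Table); convert every
# table to a DataFrame, keyed by its table name.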
def load_user_features(path):
    result = {}
    tables = load_tables_from_pickle(path)
    for table_name, table in tables.items():
        result[table_name] = table.to_pandas()
    return result

def dataframe_to_dict(df):
    result = {}
    for _, row in df.iterrows():
        user_id = row['user_id']
        result.setdefault(user_id, []).append(row.to_dict())
    return result

def merge_dicts(dict1, dict2):
    for key, value in dict2.items():
        if key in dict1:
            dict1[key].extend(value)
        else:
            dict1[key] = value
    return dict1

def predict_proba_wrapper(X):
    return pipeline.predict_proba(X)

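
# Pull the latest feature tables and model pipeline (downloaded from GCS
# unless LOAD_LOCAL_MODEL is set) and swap them into the running service.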
def refresh_data():
    start = timer()
    global pipeline
    if os.getenv('LOAD_LOCAL_MODEL') is None:
        gcs_bucket_name = os.getenv('GCS_BUCKET')
        app.logger.info(f'loading data from {gcs_bucket_name}')
        download_from_gcs(gcs_bucket_name, f'data/features/{USER_HISTORY_PATH}', USER_HISTORY_PATH)
        download_from_gcs(gcs_bucket_name, f'data/models/{MODEL_PIPELINE_PATH}', MODEL_PIPELINE_PATH)
    pipeline = load_pipeline(MODEL_PIPELINE_PATH)
    new_features = load_user_features(USER_HISTORY_PATH)
    user_features_store.update(new_features)
    app.logger.info(f'time to refresh data (in seconds): {timer() - start}')
    app.logger.info(f'loaded pipeline: {pipeline}')
    app.logger.info(f'loaded number of user_features: {len(new_features)}')

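
# The response currently carries a single signal, so 'score' simply mirrors
# 'interaction_score'.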
def compute_score(user_id, item_features, user_features):
    interaction_score = compute_interaction_score(user_id, item_features, user_features)
    observe_score(interaction_score)
    return {
        'score': interaction_score,
        'interaction_score': interaction_score,
    }

def compute_interaction_score(user_id, item_features, user_features):
    start = timer()
    original_url_host = urlparse(item_features.get('original_url')).netloc
    df_test = pd.DataFrame([{
        'user_id': user_id,
        'author': item_features.get('author'),
        'site': item_features.get('site'),
        'subscription': item_features.get('subscription'),
        'original_url_host': original_url_host,
        'item_has_thumbnail': 1 if item_features.get('has_thumbnail') else 0,
        'item_has_site_icon': 1 if item_features.get('has_site_icon') else 0,
        'item_word_count': item_features.get('words_count'),
        'is_subscription': 1 if item_features.get('is_subscription') else 0,
        'is_newsletter': 1 if item_features.get('is_newsletter') else 0,
        'is_feed': 1 if item_features.get('is_feed') else 0,
        'days_since_subscribed': item_features.get('days_since_subscribed'),
        'subscription_count': item_features.get('subscription_count'),
        'subscription_auto_add_to_library': item_features.get('subscription_auto_add_to_library'),
        'subscription_fetch_content': item_features.get('subscription_fetch_content'),
        'has_author': 1 if item_features.get('author') else 0,
        'inbox_folder': 1 if item_features.get('folder') == 'inbox' else 0,
    }])
    # Left-join this user's rows from each feature table; the table name tells
    # us which entity the table is keyed on.
    for name, df in user_features.items():
        df = df[df['user_id'] == user_id]
        if 'author' in name:
            merge_keys = ['user_id', 'author']
        elif 'site' in name:
            merge_keys = ['user_id', 'site']
        elif 'subscription' in name:
            merge_keys = ['user_id', 'subscription']
        elif 'original_url_host' in name:
            merge_keys = ['user_id', 'original_url_host']
        else:
            app.logger.info(f'skipping feature: {name}')
            continue
        df_test = pd.merge(df_test, df, on=merge_keys, how='left')
        df_test = df_test.fillna(0)
    df_predict = df_test[FEATURE_COLUMNS]
    infer_start = timer()
    interaction_score = pipeline.predict_proba(df_predict)
    app.logger.info(f'time to call infer (in seconds): {timer() - infer_start}')
    app.logger.info(f'INTERACTION SCORE: {interaction_score}')
    app.logger.info(f'item_features:\n{df_predict[df_predict != 0].stack()}')
    app.logger.info(f'time to compute score (in seconds): {timer() - start}')
    # predict_proba returns [[p(no interaction), p(interaction)]]; take the
    # positive-class probability for the single row.
    return np.float64(interaction_score[0][1])

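
# Batch scoring: items are scored concurrently on a thread pool, all workers
# sharing one snapshot of the user features.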
def process_parallel_item(user_id, key, item, user_features):
    library_item_id = item['library_item_id']
    return library_item_id, compute_score(user_id, item, user_features)

def parallel_compute_scores(user_id, items, max_workers=None):
    user_features = user_features_store.get()
    result = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_item = {
            executor.submit(process_parallel_item, user_id, key, item, user_features): (key, item)
            for key, item in items.items()
        }
        for future in concurrent.futures.as_completed(future_to_item):
            key, item = future_to_item[future]
            try:
                library_item_id, score = future.result()
                result[library_item_id] = score
            except Exception as exc:
                app.logger.error(f'Item {key} generated an exception: {exc}')
    return result

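
# HTTP endpoints. /_ah/health and /metrics are unauthenticated; /refresh and
# the per-user feature dump require an admin token, scoring a user token.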
@app.route('/_ah/health', methods=['GET'])
def ready():
    return jsonify({'OK': 'yes'}), 200

@app.route('/metrics')
def metrics():
    return generate_latest(), 200, {'Content-Type': 'text/plain; charset=utf-8'}

@app.route('/refresh', methods=['GET'])
@admin_token_required
def refresh():
    refresh_data()
    return jsonify({'OK': 'yes'}), 200

@app.route('/users/<user_id>/features', methods=['GET'])
@admin_token_required
def get_user_features(user_id):
    user_features = user_features_store.get()
    user_data = {}
    for df in user_features.values():
        df = df[df['user_id'] == user_id]
        user_data = merge_dicts(user_data, dataframe_to_dict(df))
    return jsonify(user_data), 200

@app.route('/predict', methods=['POST'])
@user_token_required
def predict():
    try:
        data = request.get_json()
        app.logger.info(f'predict scoring request: {data}')
        user_id = request.user_id
        item_features = data.get('item_features')
        if user_id is None:
            return jsonify({'error': 'Missing user_id'}), 400
        user_features = user_features_store.get()
        score = compute_score(user_id, item_features, user_features)
        return jsonify({'score': score})
    except Exception as e:
        app.logger.error(f'exception in predict endpoint: {request.get_json()}\n{e}')
        return jsonify({'error': str(e)}), 500

@app.route('/batch', methods=['POST'])
@user_token_required
def batch():
    start = timer()
    try:
        data = request.get_json()
        items = data.get('items')
        user_id = request.user_id
        if user_id is None:
            return jsonify({'error': 'no user_id supplied'}), 400
        if items is None:
            return jsonify({'error': 'no items supplied'}), 400
        if len(items) > 101:
            return jsonify({'error': f'too many items: {len(items)}'}), 400
        result = parallel_compute_scores(user_id, items)
        app.logger.info(f'time to compute batch of {len(items)} items (in seconds): {timer() - start}')
        return jsonify(result)
    except Exception as e:
        app.logger.error(f'exception in batch endpoint: {request.get_json()}\n{e}')
        return jsonify({'error': str(e)}), 500

# Run this on startup in both production and development modes.
refresh_data()

if __name__ == '__main__':
    app.run(debug=True, port=5000)