Merge pull request #4113 from omnivore-app/feat/ml-refresh

Add internal function for refreshing features and model
This commit is contained in:
Jackson Harper
2024-06-26 13:45:20 +08:00
committed by GitHub

View File

@ -2,6 +2,7 @@ import logging
from flask import Flask, request, jsonify
from typing import List
from timeit import default_timer as timer
import os
import sys
@ -24,6 +25,8 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
USER_HISTORY_PATH = 'user_features.pkl'
MODEL_PIPELINE_PATH = 'predict_read_pipeline-v002.pkl'
pipeline = None
user_features = None
def download_from_gcs(bucket_name, gcs_path, destination_path):
storage_client = storage.Client()
@ -69,6 +72,21 @@ def merge_dicts(dict1, dict2):
dict1[key] = value
return dict1
def refresh_data():
start = timer()
global pipeline
global user_features
if os.getenv('LOAD_LOCAL_MODEL') != None:
gcs_bucket_name = os.getenv('GCS_BUCKET')
download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH)
download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH)
pipeline = load_pipeline(MODEL_PIPELINE_PATH)
user_features = load_user_features(USER_HISTORY_PATH)
end = timer()
print('time to refresh data (in seconds):', end - start)
print('loaded pipeline:', pipeline)
print('loaded number of user_features:', len(user_features))
def compute_score(user_id, item_features):
interaction_score = compute_interaction_score(user_id, item_features)
@ -132,6 +150,12 @@ def ready():
return jsonify({'OK': 'yes'}), 200
@app.route('/refresh', methods=['GET'])
def refresh():
refresh_data()
return jsonify({'OK': 'yes'}), 200
@app.route('/users/<user_id>/features', methods=['GET'])
def get_user_features(user_id):
result = {}
@ -191,16 +215,6 @@ def batch():
app.logger.error(f"exception in batch endpoint: {request.get_json()}\n{e}")
return jsonify({'error': str(e)}), 500
if os.getenv('LOAD_LOCAL_MODEL') != None:
gcs_bucket_name = os.getenv('GCS_BUCKET')
download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH)
download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH)
pipeline = load_pipeline(MODEL_PIPELINE_PATH)
user_features = load_user_features(USER_HISTORY_PATH)
print('loaded pipeline and user_features', pipeline, user_features)
if __name__ == '__main__':
refresh_data()
app.run(debug=True, port=5000)