diff --git a/ml/digest-score/app.py b/ml/digest-score/app.py index f9c428100..0e83dab9d 100644 --- a/ml/digest-score/app.py +++ b/ml/digest-score/app.py @@ -2,6 +2,7 @@ import logging from flask import Flask, request, jsonify from typing import List +from timeit import default_timer as timer import os import sys @@ -24,6 +25,8 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout) USER_HISTORY_PATH = 'user_features.pkl' MODEL_PIPELINE_PATH = 'predict_read_pipeline-v002.pkl' +pipeline = None +user_features = None def download_from_gcs(bucket_name, gcs_path, destination_path): storage_client = storage.Client() @@ -69,6 +72,21 @@ def merge_dicts(dict1, dict2): dict1[key] = value return dict1 +def refresh_data(): + start = timer() + global pipeline + global user_features + if os.getenv('LOAD_LOCAL_MODEL') != None: + gcs_bucket_name = os.getenv('GCS_BUCKET') + download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH) + download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH) + pipeline = load_pipeline(MODEL_PIPELINE_PATH) + user_features = load_user_features(USER_HISTORY_PATH) + end = timer() + print('time to refresh data (in seconds):', end - start) + print('loaded pipeline:', pipeline) + print('loaded number of user_features:', len(user_features)) + def compute_score(user_id, item_features): interaction_score = compute_interaction_score(user_id, item_features) @@ -132,6 +150,12 @@ def ready(): return jsonify({'OK': 'yes'}), 200 +@app.route('/refresh', methods=['GET']) +def refresh(): + refresh_data() + return jsonify({'OK': 'yes'}), 200 + + @app.route('/users//features', methods=['GET']) def get_user_features(user_id): result = {} @@ -191,16 +215,6 @@ def batch(): app.logger.error(f"exception in batch endpoint: {request.get_json()}\n{e}") return jsonify({'error': str(e)}), 500 - -if os.getenv('LOAD_LOCAL_MODEL') != None: - gcs_bucket_name = os.getenv('GCS_BUCKET') - download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH) - download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH) - - -pipeline = load_pipeline(MODEL_PIPELINE_PATH) -user_features = load_user_features(USER_HISTORY_PATH) -print('loaded pipeline and user_features', pipeline, user_features) - if __name__ == '__main__': + refresh_data() app.run(debug=True, port=5000) \ No newline at end of file