Add internal function for refreshing features and model
This commit is contained in:
@ -2,6 +2,7 @@ import logging
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
from typing import List
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import os
|
||||
import sys
|
||||
@ -24,6 +25,8 @@ logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||
USER_HISTORY_PATH = 'user_features.pkl'
|
||||
MODEL_PIPELINE_PATH = 'predict_read_pipeline-v002.pkl'
|
||||
|
||||
pipeline = None
|
||||
user_features = None
|
||||
|
||||
def download_from_gcs(bucket_name, gcs_path, destination_path):
|
||||
storage_client = storage.Client()
|
||||
@ -69,6 +72,21 @@ def merge_dicts(dict1, dict2):
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
||||
def refresh_data():
|
||||
start = timer()
|
||||
global pipeline
|
||||
global user_features
|
||||
if os.getenv('LOAD_LOCAL_MODEL') != None:
|
||||
gcs_bucket_name = os.getenv('GCS_BUCKET')
|
||||
download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH)
|
||||
download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH)
|
||||
pipeline = load_pipeline(MODEL_PIPELINE_PATH)
|
||||
user_features = load_user_features(USER_HISTORY_PATH)
|
||||
end = timer()
|
||||
print('time to refresh data (in seconds):', end - start)
|
||||
print('loaded pipeline:', pipeline)
|
||||
print('loaded number of user_features:', len(user_features))
|
||||
|
||||
|
||||
def compute_score(user_id, item_features):
|
||||
interaction_score = compute_interaction_score(user_id, item_features)
|
||||
@ -132,6 +150,12 @@ def ready():
|
||||
return jsonify({'OK': 'yes'}), 200
|
||||
|
||||
|
||||
@app.route('/refresh', methods=['GET'])
|
||||
def refresh():
|
||||
refresh_data()
|
||||
return jsonify({'OK': 'yes'}), 200
|
||||
|
||||
|
||||
@app.route('/users/<user_id>/features', methods=['GET'])
|
||||
def get_user_features(user_id):
|
||||
result = {}
|
||||
@ -191,16 +215,6 @@ def batch():
|
||||
app.logger.error(f"exception in batch endpoint: {request.get_json()}\n{e}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
if os.getenv('LOAD_LOCAL_MODEL') != None:
|
||||
gcs_bucket_name = os.getenv('GCS_BUCKET')
|
||||
download_from_gcs(gcs_bucket_name, f'data/features/user_features.pkl', USER_HISTORY_PATH)
|
||||
download_from_gcs(gcs_bucket_name, f'data/models/predict_read_pipeline-v002.pkl', MODEL_PIPELINE_PATH)
|
||||
|
||||
|
||||
pipeline = load_pipeline(MODEL_PIPELINE_PATH)
|
||||
user_features = load_user_features(USER_HISTORY_PATH)
|
||||
print('loaded pipeline and user_features', pipeline, user_features)
|
||||
|
||||
if __name__ == '__main__':
|
||||
refresh_data()
|
||||
app.run(debug=True, port=5000)
|
||||
Reference in New Issue
Block a user