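r"""Streaming Apache Beam pipeline for credit-card fraud inference.

Reads transactions from a Pub/Sub subscription, looks up each card's transaction
history in Firestore to compute aggregate features, invokes two AI Platform model
versions (with and without aggregate features), writes the scored records to
BigQuery, and publishes transactions flagged as fraud to a notification Pub/Sub
topic.

Example invocation (argument values below are illustrative placeholders; the usual
Beam/Dataflow options such as --runner and --streaming are passed alongside them):

    python inference_pipeline.py \
        --firestore-project <gcp-project-id> \
        --subscription-name <pubsub-subscription> \
        --firestore-collection <firestore-collection-id> \
        --dataset-id <bigquery-dataset> \
        --table-name <bigquery-table> \
        --model-name <ai-platform-model> \
        --model-with-aggregates <version-with-aggregates> \
        --model-without-aggregates <version-without-aggregates> \
        --fraud-notification-topic <pubsub-topic>
"""
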
import argparse
import json
import logging
from math import cos, asin, sqrt, pi
import time
import warnings
from datetime import datetime, timedelta
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
from google.cloud import bigquery
from google.cloud import firestore
from googleapiclient import discovery

warnings.filterwarnings(action='ignore')


def distance(lat1, lon1, lat2, lon2):
    """Given two locations with latitude and longitude,
    calculates distance between them"""
    p = pi/180
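    # Haversine formula: a = sin^2(dlat/2) + cos(lat1) * cos(lat2) * sin^2(dlon/2),
    # distance = 2 * R * asin(sqrt(a)), where 2 * R = 12742 km is the Earth's diameter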
    a = 0.5 - cos((lat2-lat1)*p)/2 + cos(lat1*p) * cos(lat2*p) * (1-cos((lon2-lon1)*p))/2
    return 12742 * asin(sqrt(a))

class LookupFirestore(beam.DoFn):
    """Lookups firestore and computes aggregate features 
    required for ML model inference"""
    def __init__(self, project_id, firestore_collection):
        self.project_id = project_id
        self.firestore_collection = firestore_collection

    def start_bundle(self):
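        # Create the Firestore client per bundle on the worker rather than in __init__,
        # so the client (which holds network connections) is not pickled with the DoFn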
        self.db = firestore.Client(project=self.project_id)

    def process(self, elem):
        trans_date_trans_time = datetime.strptime(elem['trans_date_trans_time'], '%Y-%m-%d %H:%M:%S')
        last_day_date_ts = trans_date_trans_time - timedelta(days=1)
        last_week_date_ts = trans_date_trans_time - timedelta(days=7)
        last_month_date_ts = trans_date_trans_time - timedelta(days=30)
        # Day adjustment: Python's weekday() treats Sunday as 6, while BigQuery treats Sunday as 1
        elem['day'] = (trans_date_trans_time.weekday()+2)%7
        # Age of the customer at the time of transaction
        elem['age'] = round((trans_date_trans_time - datetime.strptime(elem['dob'], '%Y-%m-%d')).days/365, 2)
        # Distance between customer location and merchant location
        elem['distance'] = round(distance(elem['lat'], elem['long'], elem['merch_lat'], elem['merch_long']), 2)
        # Look up Firestore and fetch the document for this credit card number
        cc_num = elem['cc_num']
        db_ref = self.db.collection(self.firestore_collection)
        query = db_ref.where(u'cc_num', u'==', cc_num)
        docs = query.get()
        # If Firestore returns zero documents, add the document with the current transaction
        if not docs:
            print("No document")
            doc_ref = db_ref.add({
                'cc_num': elem['cc_num'],
                'trans_details': [{
                    'amt': elem['amt'],
                    'trans_date_trans_time': trans_date_trans_time.replace(tzinfo=None)
                }]
            })
            print(f'Added document with ID {doc_ref[1].id}')
            elem['trans_diff'] = 0
            elem['avg_spend_pw'] = 0
            elem['avg_spend_pm'] = 0
            elem['trans_freq_24'] = 0
        # In case of an existing document, update the transaction history
        else:
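            # A card is expected to have a single Firestore document; pick up its ID and contents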
            for item in docs:
                doc = {
                    'id': item.id,
                    'doc': item.to_dict()
                }
            trans_details = doc['doc']['trans_details']
            # If transaction history is available, compute the derived features
            if trans_details:
                # Minutes elapsed since the most recent previous transaction
                last_trans_time = trans_details[0]['trans_date_trans_time'].replace(tzinfo=None)
                elem['trans_diff'] = int((trans_date_trans_time - last_trans_time).total_seconds()/60)

                trans_freq_24 = 0
                last_week_freq = 0
                last_week_spend = 0
                last_month_freq = 0
                last_month_spend = 0
                # History is kept newest-first, so stop at the first transaction older
                # than a month and trim everything from that point onwards
                cutoff = len(trans_details)
                for i, trans in enumerate(trans_details):
                    trans_time = trans['trans_date_trans_time'].replace(tzinfo=None)
                    if trans_time < last_month_date_ts:
                        cutoff = i
                        break
                    last_month_freq += 1
                    last_month_spend += trans['amt']
                    if trans_time >= last_week_date_ts:
                        last_week_freq += 1
                        last_week_spend += trans['amt']
                    if trans_time >= last_day_date_ts:
                        trans_freq_24 += 1

                trans_details = trans_details[:cutoff]
                elem['avg_spend_pw'] = last_week_spend/last_week_freq if last_week_freq else 0
                elem['avg_spend_pm'] = last_month_spend/last_month_freq if last_month_freq else 0
                elem['trans_freq_24'] = trans_freq_24
            # in case of no history, provide default values
            else:
                elem['trans_diff'] = 0
                elem['avg_spend_pw'] = 0
                elem['avg_spend_pm'] = 0
                elem['trans_freq_24'] = 0

            # Insert the current transaction at the start of the transaction history
            trans_details.insert(0, {'amt':elem['amt'], 'trans_date_trans_time': trans_date_trans_time})
            doc_ref = db_ref.document(doc['id'])
            doc_ref.update({'trans_details': trans_details})

        return [elem]


class InvokeMLEndpoint(beam.DoFn):
    """Sends two requests to AI platform hosted models, 
    one for model without aggregates and one for model with aggregates"""
    def __init__(self, project_id, model_name, model_w_agg, model_wo_agg):
        # Fully qualified resource names of the model versions hosted on AI Platform
        self.model_w_agg = f'projects/{project_id}/models/{model_name}/versions/{model_w_agg}'
        self.model_wo_agg = f'projects/{project_id}/models/{model_name}/versions/{model_wo_agg}'

    def start_bundle(self):
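        # Build the AI Platform ('ml', v1) API client per bundle on the worker;
        # cache_discovery=False avoids the discovery file-cache warning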
        self.service = discovery.build('ml', 'v1', cache_discovery=False)

    def process(self, elem):
        def predict_json(model_full_name, instances):
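            """Calls the AI Platform online prediction API (projects.predict) for the given
            model version and returns ('Success', predictions) or ('Fail', error details)."""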
            response = self.service.projects().predict(
                name=model_full_name,
                body={'instances': instances}
            ).execute()
            if 'error' in response:
                print(response)
                return 'Fail', response['error']

            return 'Success', response['predictions']

        # Request Instance for model without aggregates
        wo_agg_instance = [{
            "category": elem['category'],
            "amt": elem['amt'],
            "state": elem['state'],
            "job": elem['job'],
            "unix_time": elem['unix_time'],
            "city_pop": elem['city_pop'],
            "merchant": elem['merchant'],
            "day": elem['day'],
            "age": elem['age'],
            "distance": elem['distance']
        }]
        # Get prediction and store prediction and confidence scores
        output = predict_json(self.model_wo_agg, wo_agg_instance)
        if output[0] == 'Success':
            elem['is_fraud_model_wo_aggregates'] = int(output[1][0]['predicted_is_fraud'])
            elem['prob_is_fraud_model_wo_aggregates'] = [round(output[1][0]['is_fraud_probs'][0],5), round(output[1][0]['is_fraud_probs'][1],5)]

        # Request Instance for model with aggregates
        w_agg_instance = [{
            "category": elem['category'],
            "amt": elem['amt'],
            "state": elem['state'],
            "job": elem['job'],
            "unix_time": elem['unix_time'],
            "city_pop": elem['city_pop'],
            "merchant": elem['merchant'],
            "day": elem['day'],
            "age": elem['age'],
            "distance": elem['distance'],
            "trans_freq_24": elem['trans_freq_24'],
            "trans_diff": elem['trans_diff'],
            "avg_spend_pw": elem['avg_spend_pm'],
            "avg_spend_pm": elem['avg_spend_pw'],
        }]
        # Get prediction and store prediction and confidence scores
        output = predict_json(self.model_w_agg, w_agg_instance)
        if output[0] == 'Success':
            elem['is_fraud_model_w_aggregates'] = int(output[1][0]['predicted_is_fraud'])
            elem['prob_is_fraud_model_w_aggregates'] = [round(output[1][0]['is_fraud_probs'][0],5), round(output[1][0]['is_fraud_probs'][1],5)]

        return [elem]

class FilterFraud(beam.DoFn):
    def process(self, elem):
        """Filter transactions which are predicted as fraud by model with aggregates or any prioritized model"""
        if elem.get('is_fraud_model_w_aggregates') == 1:
            return [str(elem).encode('utf-8')]


def run(argv=None):
    """Run the workflow."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--firestore-project',
                        dest='firestore_project',
                        required=True,
                        help='Firestore Project ID')
    parser.add_argument('--subscription-name',
                        dest='subscription_name',
                        required=True,
                        help='PubSub subscription name to be consumed')
    parser.add_argument('--firestore-collection',
                        dest='firestore_collection',
                        required=True,
                        help='Collection ID storing transaction details')
    parser.add_argument('--dataset-id',
                        dest='dataset_id',
                        required=True,
                        help='BigQuery dataset ID containing transaction data')
    parser.add_argument('--table-name',
                        dest='table_name',
                        required=True,
                        help='BigQuery table name containing transaction data')
    parser.add_argument('--model-name',
                        dest='model_name',
                        required=True,
                        help='AI Platform model name')
    parser.add_argument('--model-with-aggregates',
                        dest='model_w_agg',
                        required=True,
                        help='Version name of the model with aggregates')
    parser.add_argument('--model-without-aggregates',
                        dest='model_wo_agg',
                        required=True,
                        help='Version name of the model without aggregates')
    parser.add_argument('--fraud-notification-topic',
                        dest='fraud_notification_topic',
                        required=True,
                        help='PubSub topic name to receive notifications for fraudulent transactions')

    known_args, pipeline_args = parser.parse_known_args(argv)

    project_id = known_args.firestore_project
    subscription_name = known_args.subscription_name
    firestore_collection = known_args.firestore_collection
    dataset_id = known_args.dataset_id
    table_name = known_args.table_name
    model_name = known_args.model_name
    model_w_agg = known_args.model_w_agg
    model_wo_agg = known_args.model_wo_agg
    fraud_notification_topic = known_args.fraud_notification_topic
    
    subscription_full_path = f'projects/{project_id}/subscriptions/{subscription_name}'

    # We use the save_main_session option because one or more DoFn's in this
    # workflow rely on global context (e.g., a module imported at module level).
    pipeline_options = PipelineOptions(pipeline_args)
    pipeline_options.view_as(SetupOptions).save_main_session = True

    p = beam.Pipeline(options=pipeline_options)

    input_transactions = p | "Read txns from PubSub" >> beam.io.ReadFromPubSub(subscription=subscription_full_path)

    processed_data = input_transactions | "Convert to JSON" >> beam.Map(lambda x: json.loads(x)) \
        | "Lookup Historical txns" >> beam.ParDo(LookupFirestore(project_id, firestore_collection))

    ml_output = processed_data | "Invoke ML Model" >> beam.ParDo(InvokeMLEndpoint(project_id, model_name, model_w_agg, model_wo_agg))
    ml_output | "Filter fraudulent txns" >> beam.ParDo(FilterFraud()) \
        | "Notifications to PubSub" >> beam.io.WriteToPubSub(topic = f'projects/{project_id}/topics/{fraud_notification_topic}')
    ml_output | "Write pred to BigQuery" >> beam.io.WriteToBigQuery(
            table=f'{project_id}:{dataset_id}.{table_name}',
            write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND,
            create_disposition=beam.io.BigQueryDisposition.CREATE_NEVER,
            method="STREAMING_INSERTS")

    result = p.run()


if __name__ == "__main__":
    run()